package org.apache.samoa.learners.classifiers.trees;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.collect.ImmutableSet;

import java.util.Set;

import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.AdaptiveLearner;
import org.apache.samoa.learners.ClassificationLearner;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;

/**
 * Vertical Hoeffding Tree.
 * <p/>
 * Vertical Hoeffding Tree (VHT) classifier is a distributed classifier that utilizes vertical parallelism on top of
 * Very Fast Decision Tree (VFDT) classifier.
 * 
 * @author Arinto Murdopo
 */
public final class VerticalHoeffdingTree implements ClassificationLearner, AdaptiveLearner, Configurable {

  private static final long serialVersionUID = -4937416312929984057L;

  public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
      'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
      "GaussianNumericAttributeClassObserver");

  public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
      'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
      "NominalAttributeClassObserver");

  public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
      's', "Split criterion to use.", SplitCriterion.class,
      "InfoGainSplitCriterion");

  public FloatOption splitConfidenceOption = new FloatOption(
      "splitConfidence",
      'c',
      "The allowable error in split decision, values closer to 0 will take longer to decide.",
      0.0000001, 0.0, 1.0);

  public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
      't', "Threshold below which a split will be forced to break ties.",
      0.05, 0.0, 1.0);

  public IntOption gracePeriodOption = new IntOption(
      "gracePeriod",
      'g',
      "The number of instances a leaf should observe between split attempts.",
      200, 0, Integer.MAX_VALUE);

  public IntOption parallelismHintOption = new IntOption(
      "parallelismHint",
      'p',
      "The number of local statistics PI to do distributed computation",
      1, 1, Integer.MAX_VALUE);

  public IntOption timeOutOption = new IntOption(
      "timeOut",
      'o',
      "The duration to wait all distributed computation results from local statistics PI",
      30, 1, Integer.MAX_VALUE);

  public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
      "Only allow binary splits.");

  private Stream resultStream;

  private FilterProcessor filterProc;

  @Override
  public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {

    this.filterProc = new FilterProcessor.Builder(dataset)
        .build();
    topologyBuilder.addProcessor(filterProc, parallelism);

    Stream filterStream = topologyBuilder.createStream(filterProc);
    this.filterProc.setOutputStream(filterStream);

    ModelAggregatorProcessor modelAggrProc = new ModelAggregatorProcessor.Builder(dataset)
        .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
        .splitConfidence(splitConfidenceOption.getValue())
        .tieThreshold(tieThresholdOption.getValue())
        .gracePeriod(gracePeriodOption.getValue())
        .parallelismHint(parallelismHintOption.getValue())
        .timeOut(timeOutOption.getValue())
        .changeDetector(this.getChangeDetector())
        .build();

    topologyBuilder.addProcessor(modelAggrProc, parallelism);

    topologyBuilder.connectInputShuffleStream(filterStream, modelAggrProc);

    this.resultStream = topologyBuilder.createStream(modelAggrProc);
    modelAggrProc.setResultStream(resultStream);

    Stream attributeStream = topologyBuilder.createStream(modelAggrProc);
    modelAggrProc.setAttributeStream(attributeStream);

    Stream controlStream = topologyBuilder.createStream(modelAggrProc);
    modelAggrProc.setControlStream(controlStream);

    LocalStatisticsProcessor locStatProc = new LocalStatisticsProcessor.Builder()
        .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
        .binarySplit(binarySplitsOption.isSet())
        .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue())
        .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue())
        .build();

    topologyBuilder.addProcessor(locStatProc, parallelismHintOption.getValue());
    topologyBuilder.connectInputKeyStream(attributeStream, locStatProc);
    topologyBuilder.connectInputAllStream(controlStream, locStatProc);

    Stream computeStream = topologyBuilder.createStream(locStatProc);

    locStatProc.setComputationResultStream(computeStream);
    topologyBuilder.connectInputAllStream(computeStream, modelAggrProc);
  }

  @Override
  public Processor getInputProcessor() {
    return this.filterProc;
  }

  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }

  protected ChangeDetector changeDetector;

  @Override
  public ChangeDetector getChangeDetector() {
    return this.changeDetector;
  }

  @Override
  public void setChangeDetector(ChangeDetector cd) {
    this.changeDetector = cd;
  }

  static class LearningNodeIdGenerator {

    // TODO: add code to warn user of when value reaches Long.MAX_VALUES
    private static long id = 0;

    static synchronized long generate() {
      return id++;
    }
  }
}
