package com.yahoo.labs.samoa.learners.classifiers.rules.distributed;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.InstanceContentEvent;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * Default Rule Learner Processor (HAMR).
 * <p>
 * Owns the single default rule of the learner. Every incoming {@link InstanceContentEvent} is
 * used to predict with the default rule (the vote is published on the result stream) and/or to
 * train it. When the default rule manages to expand, the expanded rule is published on the rule
 * stream as a {@link RuleContentEvent} and a fresh default rule replaces it locally.
 * 
 * @author Anh Thu Vu
 * 
 */
public class AMRDefaultRuleProcessor implements Processor {

  /** Serialization version identifier. */
  private static final long serialVersionUID = 23702084591044447L;

  private static final Logger logger =
      LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);

  /** Id received in {@link #onCreate(int)}; set as the classifier index of every result event. */
  private int processorId;

  // Default rule state. These fields are transient and are (re)initialized in onCreate(int).
  protected transient ActiveRule defaultRule;
  protected transient int ruleNumberID;
  protected transient double[] statistics;

  // SAMOA Stream
  private Stream ruleStream;
  private Stream resultStream;

  // Options
  protected int pageHinckleyThreshold;
  protected double pageHinckleyAlpha;
  protected boolean driftDetection;
  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  protected boolean constantLearningRatioDecay;
  protected double learningRatio;

  protected double splitConfidence;
  protected double tieThreshold;
  protected int gracePeriod;

  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

  /**
   * Creates a processor configured with the option values held by the given builder.
   * <p>
   * A grace period of {@code 0} is accepted but reported with a warning, because
   * {@link #trainOnInstanceImpl(Instance)} then never attempts to expand the default rule.
   * 
   * @param builder
   *          source of all option values; must not be {@code null}
   */
  public AMRDefaultRuleProcessor(Builder builder) {
    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
    this.driftDetection = builder.driftDetection;
    this.predictionFunction = builder.predictionFunction;
    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
    this.learningRatio = builder.learningRatio;
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;

    this.numericObserver = builder.numericObserver;

    if (this.gracePeriod == 0) {
      logger.warn("gracePeriod is 0: the default rule of {} will never attempt to expand",
          AMRDefaultRuleProcessor.class.getSimpleName());
    }
  }

  /**
   * Handles one instance event: predicts if the event is flagged for testing, then trains if it
   * is flagged for training. Prediction runs first, so the emitted vote reflects the default rule
   * as it was before learning from this instance.
   * 
   * @param event
   *          the incoming event; it is cast to {@link InstanceContentEvent}
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
    // predict
    if (instanceEvent.isTesting()) {
      this.predictOnInstance(instanceEvent);
    }

    // train
    if (instanceEvent.isTraining()) {
      this.trainOnInstance(instanceEvent);
    }

    return false;
  }

  /*
   * Prediction
   */

  /** Asks the default rule for a prediction and publishes it on the result stream. */
  private void predictOnInstance(InstanceContentEvent instanceEvent) {
    double[] vote = defaultRule.getPrediction(instanceEvent.getInstance());
    ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
    resultStream.put(rce);
  }

  /**
   * Wraps a prediction into a result event that carries over the instance index, instance, class
   * id, last-event flag and evaluation index of the originating event, and is tagged with this
   * processor's id as classifier index.
   */
  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
    ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
        inEvent.getClassId(), prediction, inEvent.isLastEvent());
    rce.setClassifierIndex(this.processorId);
    rce.setEvaluationIndex(inEvent.getEvaluationIndex());
    return rce;
  }

  /*
   * Training
   */

  /** Trains the default rule with the instance carried by the event. */
  private void trainOnInstance(InstanceContentEvent instanceEvent) {
    this.trainOnInstanceImpl(instanceEvent.getInstance());
  }

  /**
   * Updates the default rule with one instance and, each time the number of instances seen by
   * the rule is a multiple of the grace period, tries to expand it.
   * <p>
   * On a successful expansion a replacement default rule is prepared (keeping the current rule
   * number and seeded from the other-branch statistics of the expanded node), the current rule is
   * split, given the next rule number, sent out on the rule stream, and finally swapped for the
   * replacement.
   * <p>
   * A grace period of {@code 0} disables the expansion check instead of dividing by zero.
   * 
   * @param instance
   *          the training instance
   */
  public void trainOnInstanceImpl(Instance instance) {
    defaultRule.updateStatistics(instance);

    // The remainder below is undefined for a zero divisor (integer remainder by zero throws).
    if (this.gracePeriod == 0) {
      return;
    }
    if (defaultRule.getInstancesSeen() % this.gracePeriod != 0) {
      return;
    }
    if (!defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
      return;
    }

    // Build the replacement BEFORE split(): it is seeded from the node's pre-split state.
    RuleActiveRegressionNode expandedNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
    ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), expandedNode,
        expandedNode.getStatisticsOtherBranchSplit()); // other branch

    defaultRule.split();
    defaultRule.setRuleNumberID(++ruleNumberID);
    // send out the new rule
    sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
    defaultRule = newDefaultRule;
  }

  /*
   * Create new rules
   */

  /**
   * Creates a rule with the given id and seeds its learning node from an existing node.
   * <ul>
   * <li>If {@code node} has a perceptron, the new rule gets a copy of it with this processor's
   * learning ratio.</li>
   * <li>If {@code statistics} is {@code null}, the target mean is reset (with count 1) to
   * value(1) / value(0) of the node statistics, provided value(0) is positive.</li>
   * <li>If {@code statistics} is given and the new rule has a target mean, it is reset to
   * {@code statistics[1] / statistics[0]} with count {@code statistics[0]}, provided
   * {@code statistics[0]} is positive.</li>
   * </ul>
   * 
   * @param ID
   *          id of the new rule
   * @param node
   *          node to copy state from; may be {@code null}
   * @param statistics
   *          statistics used to seed the target mean; may be {@code null}
   * @return the newly built rule
   */
  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
    ActiveRule r = newRule(ID);

    if (node != null) {
      if (node.getPerceptron() != null) {
        r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
        r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
      }
      if (statistics == null && node.getNodeStatistics().getValue(0) > 0) {
        double mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
        r.getLearningNode().getTargetMean().reset(mean, 1);
      }
    }

    if (statistics != null) {
      RuleActiveRegressionNode learner = (RuleActiveRegressionNode) r.getLearningNode();
      if (learner.getTargetMean() != null && statistics[0] > 0) {
        double mean = statistics[1] / statistics[0];
        learner.getTargetMean().reset(mean, (long) statistics[0]);
      }
    }
    return r;
  }

  /**
   * Builds an empty rule with the given id from this processor's options and a zeroed
   * three-element statistics array.
   */
  private ActiveRule newRule(int ID) {
    return new ActiveRule.Builder().
        threshold(this.pageHinckleyThreshold).
        alpha(this.pageHinckleyAlpha).
        changeDetection(this.driftDetection).
        predictionFunction(this.predictionFunction).
        statistics(new double[3]).
        learningRatio(this.learningRatio).
        numericObserver(numericObserver).
        id(ID).build();
  }

  /**
   * Initializes the transient state: stores the processor id, zeroes the statistics, resets the
   * rule counter and creates the first default rule (rule number 1).
   * 
   * @param id
   *          id assigned to this processor
   */
  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.statistics = new double[] { 0.0, 0, 0 };
    this.ruleNumberID = 0;
    this.defaultRule = newRule(++this.ruleNumberID);
  }

  /*
   * Clone processor
   */

  /**
   * Creates a new processor with the same options and the same output streams as {@code p}.
   * The transient rule state is not copied; it is set up by {@link #onCreate(int)}.
   * 
   * @param p
   *          the processor to copy; it is cast to {@link AMRDefaultRuleProcessor}
   * @return the new processor
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p;
    Builder builder = new Builder(oldProcessor);
    AMRDefaultRuleProcessor newProcessor = builder.build();
    newProcessor.resultStream = oldProcessor.resultStream;
    newProcessor.ruleStream = oldProcessor.ruleStream;
    return newProcessor;
  }

  /*
   * Send events
   */

  /** Publishes the given rule on the rule stream (the event's boolean flag is {@code false}). */
  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
    RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
    this.ruleStream.put(rce);
  }

  /*
   * Output streams
   */

  /**
   * @param ruleStream
   *          stream on which expanded rules are published
   */
  public void setRuleStream(Stream ruleStream) {
    this.ruleStream = ruleStream;
  }

  /**
   * @return the stream on which expanded rules are published
   */
  public Stream getRuleStream() {
    return this.ruleStream;
  }

  /**
   * @param resultStream
   *          stream on which prediction results are published
   */
  public void setResultStream(Stream resultStream) {
    this.resultStream = resultStream;
  }

  /**
   * @return the stream on which prediction results are published
   */
  public Stream getResultStream() {
    return this.resultStream;
  }

  /**
   * Fluent builder for {@link AMRDefaultRuleProcessor}. Options that are never set keep their
   * Java default values ({@code 0}, {@code false} or {@code null}).
   */
  public static class Builder {
    private int pageHinckleyThreshold;
    private double pageHinckleyAlpha;
    private boolean driftDetection;
    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    private boolean constantLearningRatioDecay;
    private double learningRatio;
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;

    private FIMTDDNumericAttributeClassLimitObserver numericObserver;

    // Stored by the dataset constructor; not read anywhere in this class.
    private Instances dataset;

    /**
     * Creates a builder with all options at their Java defaults.
     * 
     * @param dataset
     *          dataset kept by the builder
     */
    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /**
     * Creates a builder pre-populated with every option of an existing processor.
     * 
     * @param processor
     *          the processor whose options are copied
     */
    public Builder(AMRDefaultRuleProcessor processor) {
      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
      this.driftDetection = processor.driftDetection;
      this.predictionFunction = processor.predictionFunction;
      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
      this.learningRatio = processor.learningRatio;
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;

      this.numericObserver = processor.numericObserver;
    }

    /** Sets the Page-Hinckley threshold. */
    public Builder threshold(int threshold) {
      this.pageHinckleyThreshold = threshold;
      return this;
    }

    /** Sets the Page-Hinckley alpha. */
    public Builder alpha(double alpha) {
      this.pageHinckleyAlpha = alpha;
      return this;
    }

    /** Enables or disables drift (change) detection. */
    public Builder changeDetection(boolean changeDetection) {
      this.driftDetection = changeDetection;
      return this;
    }

    /** Sets the prediction function: Adaptive=0, Perceptron=1, TargetMean=2. */
    public Builder predictionFunction(int predictionFunction) {
      this.predictionFunction = predictionFunction;
      return this;
    }

    /** Sets the constant-learning-ratio-decay flag. */
    public Builder constantLearningRatioDecay(boolean constantDecay) {
      this.constantLearningRatioDecay = constantDecay;
      return this;
    }

    /** Sets the learning ratio applied to the rules' perceptrons. */
    public Builder learningRatio(double learningRatio) {
      this.learningRatio = learningRatio;
      return this;
    }

    /** Sets the split confidence passed to the rule when it tries to expand. */
    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    /** Sets the tie threshold passed to the rule when it tries to expand. */
    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    /**
     * Sets the grace period, i.e. the number of instances between two expansion attempts.
     * A value of {@code 0} disables expansion attempts.
     */
    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    /** Sets the numeric attribute observer handed to every rule created by the processor. */
    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
      this.numericObserver = numericObserver;
      return this;
    }

    /**
     * @return a new processor configured from this builder
     */
    public AMRDefaultRuleProcessor build() {
      return new AMRDefaultRuleProcessor(this);
    }
  }

}
