package com.yahoo.labs.samoa.learners.classifiers.rules.distributed;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.InstanceContentEvent;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.LearningRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.PassiveRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * Model Aggregator Processor (VAMR) of the distributed AMRules learner.
 * 
 * <p>
 * This processor holds the current rule set (as {@link PassiveRule}s) and the default rule.
 * For every {@link InstanceContentEvent} it:
 * <ul>
 * <li>predicts with the rules covering the instance, or with the default rule when no rule
 * covers it, and emits the prediction on the result stream (testing events only);</li>
 * <li>updates the covering rules and forwards the instance to the statistics processors, or
 * trains and possibly expands the default rule when no rule accepts it (training events
 * only).</li>
 * </ul>
 * It also applies split-node updates ({@link PredicateContentEvent}) and rule removals
 * ({@link RuleContentEvent}) coming back from the statistics processors.
 * 
 * @author Anh Thu Vu
 * 
 */
public class AMRulesAggregatorProcessor implements Processor {

  private static final long serialVersionUID = 6303385725332704251L;

  private static final Logger logger =
      LoggerFactory.getLogger(AMRulesAggregatorProcessor.class);

  private int processorId;

  // Rules & default rule (transient: (re)initialized in onCreate)
  protected transient List<PassiveRule> ruleSet;
  protected transient ActiveRule defaultRule;
  protected transient int ruleNumberID;
  protected transient double[] statistics;

  // SAMOA Stream
  private Stream statisticsStream;
  private Stream resultStream;

  // Options
  protected int pageHinckleyThreshold;
  protected double pageHinckleyAlpha;
  protected boolean driftDetection;
  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  protected boolean constantLearningRatioDecay;
  protected double learningRatio;

  protected double splitConfidence;
  protected double tieThreshold;
  protected int gracePeriod;

  protected boolean noAnomalyDetection;
  protected double multivariateAnomalyProbabilityThreshold;
  protected double univariateAnomalyprobabilityThreshold;
  protected int anomalyNumInstThreshold;

  protected boolean unorderedRules;

  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
  protected int voteType; // 1 = uniform vote; any other value = inverse-error vote

  /**
   * Creates a processor configured from the given builder.
   * 
   * @param builder
   *          source of all option values
   */
  public AMRulesAggregatorProcessor(Builder builder) {
    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
    this.driftDetection = builder.driftDetection;
    this.predictionFunction = builder.predictionFunction;
    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
    this.learningRatio = builder.learningRatio;
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;

    this.noAnomalyDetection = builder.noAnomalyDetection;
    this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
    this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
    this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
    this.unorderedRules = builder.unorderedRules;

    this.numericObserver = builder.numericObserver;
    this.voteType = builder.voteType;
  }

  /**
   * Dispatches an incoming event: instances are predicted/trained, predicate events update a
   * rule's split nodes or learning node, and removing rule events drop the rule. Other event
   * types are ignored.
   * 
   * @param event
   *          the incoming event
   * @return always {@code true}
   */
  @Override
  public boolean process(ContentEvent event) {
    if (event instanceof InstanceContentEvent) {
      InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
      this.processInstanceEvent(instanceEvent);
    }
    else if (event instanceof PredicateContentEvent) {
      this.updateRuleSplitNode((PredicateContentEvent) event);
    }
    else if (event instanceof RuleContentEvent) {
      RuleContentEvent rce = (RuleContentEvent) event;
      if (rce.isRemoving()) {
        this.removeRule(rce.getRuleNumberID());
      }
    }

    return true;
  }

  /**
   * Predicts and trains in a single pass over the rule set, so covering rules are only searched
   * once.
   * 
   * <p>
   * With ordered rules only the first covering rule votes and only the first covering,
   * non-anomalous rule is trained; with unordered rules all of them are used. A result is
   * emitted only for testing events.
   */
  private void processInstanceEvent(InstanceContentEvent instanceEvent) {
    Instance instance = instanceEvent.getInstance();
    boolean isTesting = instanceEvent.isTesting();
    boolean isTraining = instanceEvent.isTraining();

    // true only once at least one covering rule has actually added a vote
    boolean predictionCovered = false;
    // true once at least one covering rule accepted the instance for training
    boolean trainingCovered = false;
    boolean continuePrediction = isTesting;
    boolean continueTraining = isTraining;

    ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
    Iterator<PassiveRule> ruleIterator = this.ruleSet.iterator();
    while (ruleIterator.hasNext() && (continuePrediction || continueTraining)) {
      PassiveRule rule = ruleIterator.next();
      if (!rule.isCovering(instance)) {
        continue;
      }

      if (continuePrediction) {
        double[] vote = rule.getPrediction(instance);
        double error = rule.getCurrentError();
        errorWeightedVote.addVote(vote, error);
        predictionCovered = true;
        if (!this.unorderedRules) {
          continuePrediction = false;
        }
      }

      if (continueTraining && !isAnomaly(instance, rule)) {
        trainingCovered = true;
        rule.updateStatistics(instance);
        // Send instance to statistics PIs
        sendInstanceToRule(instance, rule.getRuleNumberID());
        if (!this.unorderedRules) {
          continueTraining = false;
        }
      }
    }

    // Only testing events produce a result; training-only events must not emit one
    // (previously a covered training-only instance emitted an empty combined vote).
    if (isTesting) {
      double[] prediction;
      if (predictionCovered) {
        // Combined prediction
        prediction = errorWeightedVote.computeWeightedVote();
      } else {
        // predict with default rule
        prediction = defaultRule.getPrediction(instance);
      }
      resultStream.put(newResultContentEvent(prediction, instanceEvent));
    }

    if (isTraining && !trainingCovered) {
      trainDefaultRule(instance);
    }
  }

  /**
   * Trains the default rule with an instance that no rule accepted and, every
   * {@code gracePeriod} instances, tries to expand it. On a successful expansion, the expanded
   * default rule is added to the rule set under a fresh ID and announced on the statistics
   * stream. A new default rule is then created from the statistics of the other branch of the
   * split.
   */
  private void trainDefaultRule(Instance instance) {
    defaultRule.updateStatistics(instance);
    if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
      if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
        // Built before split() so it captures the other-branch statistics of the pending split.
        RuleActiveRegressionNode learningNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
        ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
            learningNode,
            learningNode.getStatisticsOtherBranchSplit()); // other branch
        defaultRule.split();
        defaultRule.setRuleNumberID(++ruleNumberID);
        this.ruleSet.add(new PassiveRule(this.defaultRule));
        // send to statistics PI
        sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
        defaultRule = newDefaultRule;
      }
    }
  }

  /**
   * Helper method to generate a new ResultContentEvent based on an instance and its prediction.
   * 
   * @param prediction
   *          The prediction computed by the rule model (covering rules or default rule).
   * @param inEvent
   *          The associated instance content event
   * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
   */
  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
    ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
        inEvent.getClassId(), prediction, inEvent.isLastEvent());
    rce.setClassifierIndex(this.processorId);
    rce.setEvaluationIndex(inEvent.getEvaluationIndex());
    return rce;
  }

  /**
   * Creates an empty vote combiner according to {@code voteType}.
   * 
   * @return a {@link UniformWeightedVote} when {@code voteType == 1}, otherwise an
   *         {@link InverseErrorWeightedVote}
   */
  public ErrorWeightedVote newErrorWeightedVote() {
    if (voteType == 1)
      return new UniformWeightedVote();
    return new InverseErrorWeightedVote();
  }

  /**
   * Checks whether the instance is an anomaly for the given rule.
   * 
   * <p>
   * Detection is skipped (result {@code false}) when anomaly detection is disabled or when the
   * rule has seen fewer than {@code anomalyNumInstThreshold} instances.
   * 
   * @param instance
   *          the instance to check
   * @param rule
   *          the rule whose statistics are used
   * @return {@code true} if the rule flags the instance as an anomaly, {@code false} otherwise
   */
  private boolean isAnomaly(Instance instance, LearningRule rule) {
    if (!this.noAnomalyDetection && rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
      return rule.isAnomaly(instance,
          this.univariateAnomalyprobabilityThreshold,
          this.multivariateAnomalyProbabilityThreshold,
          this.anomalyNumInstThreshold);
    }
    return false;
  }

  /**
   * Creates a new rule, optionally initialized from an existing learning node.
   * 
   * <p>
   * When {@code node} has a perceptron, a copy of it is installed (with this processor's
   * learning ratio). The target mean is initialized from {@code statistics} when given, otherwise
   * from the node's own statistics (index 0 used as count, index 1 as sum).
   * 
   * @param ID
   *          rule number ID
   * @param node
   *          learning node to copy from; may be {@code null}
   * @param statistics
   *          statistics for the target mean ([0] count, [1] sum); may be {@code null}
   * @return the new rule
   */
  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
    ActiveRule r = newRule(ID);

    if (node != null)
    {
      if (node.getPerceptron() != null)
      {
        r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
        r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
      }
      // Same null guard as the statistics branch below.
      if (statistics == null && r.getLearningNode().getTargetMean() != null)
      {
        if (node.getNodeStatistics().getValue(0) > 0) {
          double mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
          r.getLearningNode().getTargetMean().reset(mean, 1);
        }
      }
    }
    if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
    {
      if (statistics[0] > 0) {
        double mean = statistics[1] / statistics[0];
        ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
      }
    }
    return r;
  }

  /**
   * Creates a fresh rule configured with this processor's options.
   * 
   * @param ID
   *          rule number ID
   * @return the new rule
   */
  private ActiveRule newRule(int ID) {
    ActiveRule r = new ActiveRule.Builder().
        threshold(this.pageHinckleyThreshold).
        alpha(this.pageHinckleyAlpha).
        changeDetection(this.driftDetection).
        predictionFunction(this.predictionFunction).
        statistics(new double[3]).
        learningRatio(this.learningRatio).
        numericObserver(numericObserver).
        id(ID).build();
    return r;
  }

  /**
   * Applies a predicate event: every rule with the event's ID gets the split node appended (if
   * present) and its learning node replaced (if present).
   */
  private void updateRuleSplitNode(PredicateContentEvent pce) {
    int ruleID = pce.getRuleNumberID();
    for (PassiveRule rule : ruleSet) {
      if (rule.getRuleNumberID() == ruleID) {
        if (pce.getRuleSplitNode() != null)
          rule.nodeListAdd(pce.getRuleSplitNode());
        if (pce.getLearningNode() != null)
          rule.setLearningNode(pce.getLearningNode());
      }
    }
  }

  /**
   * Removes the first rule with the given ID from the rule set, if any.
   */
  private void removeRule(int ruleID) {
    Iterator<PassiveRule> it = ruleSet.iterator();
    while (it.hasNext()) {
      if (it.next().getRuleNumberID() == ruleID) {
        it.remove();
        break;
      }
    }
  }

  /**
   * Initializes the transient state: an empty rule set and a default rule with ID 1.
   */
  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.statistics = new double[] { 0.0, 0, 0 };
    this.ruleNumberID = 0;
    this.defaultRule = newRule(++this.ruleNumberID);

    this.ruleSet = new LinkedList<PassiveRule>();
  }

  /**
   * Clones the processor's configuration and output streams (not its rules).
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p;
    Builder builder = new Builder(oldProcessor);
    AMRulesAggregatorProcessor newProcessor = builder.build();
    newProcessor.resultStream = oldProcessor.resultStream;
    newProcessor.statisticsStream = oldProcessor.statisticsStream;
    return newProcessor;
  }

  /*
   * Send events
   */
  private void sendInstanceToRule(Instance instance, int ruleID) {
    AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance);
    this.statisticsStream.put(ace);
  }

  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
    RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
    this.statisticsStream.put(rce);
  }

  /*
   * Output streams
   */
  public void setStatisticsStream(Stream statisticsStream) {
    this.statisticsStream = statisticsStream;
  }

  public Stream getStatisticsStream() {
    return this.statisticsStream;
  }

  public void setResultStream(Stream resultStream) {
    this.resultStream = resultStream;
  }

  public Stream getResultStream() {
    return this.resultStream;
  }

  /*
   * Others
   */
  public boolean isRandomizable() {
    return true;
  }

  /**
   * Builder for {@link AMRulesAggregatorProcessor}.
   */
  public static class Builder {
    private int pageHinckleyThreshold;
    private double pageHinckleyAlpha;
    private boolean driftDetection;
    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    private boolean constantLearningRatioDecay;
    private double learningRatio;
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;

    private boolean noAnomalyDetection;
    private double multivariateAnomalyProbabilityThreshold;
    private double univariateAnomalyprobabilityThreshold;
    private int anomalyNumInstThreshold;

    private boolean unorderedRules;

    private FIMTDDNumericAttributeClassLimitObserver numericObserver;
    private int voteType;

    // Stored but not read by build(); not copied by the copy constructor.
    private Instances dataset;

    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /**
     * Copies all option values from an existing processor.
     */
    public Builder(AMRulesAggregatorProcessor processor) {
      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
      this.driftDetection = processor.driftDetection;
      this.predictionFunction = processor.predictionFunction;
      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
      this.learningRatio = processor.learningRatio;
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;

      this.noAnomalyDetection = processor.noAnomalyDetection;
      this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
      this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
      this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
      this.unorderedRules = processor.unorderedRules;

      this.numericObserver = processor.numericObserver;
      this.voteType = processor.voteType;
    }

    public Builder threshold(int threshold) {
      this.pageHinckleyThreshold = threshold;
      return this;
    }

    public Builder alpha(double alpha) {
      this.pageHinckleyAlpha = alpha;
      return this;
    }

    public Builder changeDetection(boolean changeDetection) {
      this.driftDetection = changeDetection;
      return this;
    }

    public Builder predictionFunction(int predictionFunction) {
      this.predictionFunction = predictionFunction;
      return this;
    }

    public Builder constantLearningRatioDecay(boolean constantDecay) {
      this.constantLearningRatioDecay = constantDecay;
      return this;
    }

    public Builder learningRatio(double learningRatio) {
      this.learningRatio = learningRatio;
      return this;
    }

    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    public Builder noAnomalyDetection(boolean noAnomalyDetection) {
      this.noAnomalyDetection = noAnomalyDetection;
      return this;
    }

    public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
      this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
      return this;
    }

    public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
      this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
      return this;
    }

    public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
      this.anomalyNumInstThreshold = anomalyNumInstThreshold;
      return this;
    }

    public Builder unorderedRules(boolean unorderedRules) {
      this.unorderedRules = unorderedRules;
      return this;
    }

    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
      this.numericObserver = numericObserver;
      return this;
    }

    /**
     * Sets the vote type: 1 for uniform voting, any other value for inverse-error voting.
     */
    public Builder voteType(int voteType) {
      this.voteType = voteType;
      return this;
    }

    public AMRulesAggregatorProcessor build() {
      return new AMRulesAggregatorProcessor(this);
    }
  }

}
