package org.apache.samoa.evaluation;

import org.apache.samoa.instances.Attribute;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2016 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */


import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.core.Vote;

import java.util.Collections;
import java.util.List;
import java.util.Vector;

/**
 * Classification evaluator that keeps per-class counts of support, true/false positives and
 * true/false negatives, and reports per-class support, precision, recall and F1-score.
 *
 * <p>The number of classes is taken from the first instance passed to {@link #addResult}
 * unless it has been set explicitly with {@link #reset(int)}. Precision, recall and F1-score
 * are reported as {@code 0} when their denominator is zero (for example, a class that has
 * never been predicted) rather than as {@code NaN}.
 *
 * <p>Created by Edi Bice (edi.bice gmail com) on 2/22/2016.
 */
public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject implements
        ClassificationPerformanceEvaluator {

    private static final long serialVersionUID = 1L;

    /** Number of classes being tracked; {@code -1} (or any non-positive value) while unknown. */
    protected int numClasses = -1;

    /** Per class: number of instances whose true label is that class. */
    protected long[] support;
    /** Per class: number of instances of that class that were predicted correctly. */
    protected long[] truePos;
    /** Per class: number of instances predicted as that class whose true label differs. */
    protected long[] falsePos;
    /** Per class: number of instances neither labelled nor predicted as that class (tracked, not reported). */
    protected long[] trueNeg;
    /** Per class: number of instances of that class that were predicted as another class. */
    protected long[] falseNeg;

    /** Identifier of the most recent instance passed to {@link #addResult}. */
    private String instanceIdentifier;
    /** Most recent instance passed to {@link #addResult}; read by {@link #getPredictionVotes()}. */
    private Instance lastSeenInstance;
    /** Votes supplied with the most recent instance passed to {@link #addResult}. */
    protected double[] classVotes;

    /**
     * Clears all counters while keeping the current number of classes. If the number of
     * classes is not known yet, the evaluator stays uninitialised and sizes itself from the
     * next instance passed to {@link #addResult}.
     */
    @Override
    public void reset() {
        reset(this.numClasses);
    }

    /**
     * Clears all counters and sizes them for {@code numClasses} classes. A non-positive value
     * leaves the evaluator uninitialised, so the class count is taken from the next instance
     * passed to {@link #addResult}.
     *
     * @param numClasses number of classes to track
     */
    public void reset(int numClasses) {
        this.numClasses = numClasses;
        // Clamp so that an unknown class count (-1) cannot raise NegativeArraySizeException;
        // freshly allocated long arrays are already zero-filled.
        int size = Math.max(numClasses, 0);
        this.support = new long[size];
        this.truePos = new long[size];
        this.falsePos = new long[size];
        this.trueNeg = new long[size];
        this.falseNeg = new long[size];
        // Forget the last recorded prediction as well.
        this.instanceIdentifier = null;
        this.lastSeenInstance = null;
        this.classVotes = null;
    }

    /**
     * Records one prediction and updates the per-class counters.
     *
     * @param inst          the instance; its true class index is {@code (int) inst.classValue()}
     * @param classVotes    votes per class; the predicted class is the index of the largest vote
     * @param instanceIndex identifier of the instance, reported by {@link #getPredictionVotes()}
     * @param delay         not used by this evaluator
     */
    @Override
    public void addResult(Instance inst, double[] classVotes, String instanceIndex,
            long delay) {
        if (this.numClasses <= 0) {
            reset(inst.numClasses());
        }
        // Keep the latest prediction so getPredictionVotes() has something to report.
        this.lastSeenInstance = inst;
        this.classVotes = classVotes;
        this.instanceIdentifier = instanceIndex;

        int trueClass = (int) inst.classValue();
        int predictedClass = Utils.maxIndex(classVotes);
        this.support[trueClass]++;
        if (predictedClass == trueClass) {
            this.truePos[trueClass]++;
        } else {
            this.falsePos[predictedClass]++;
            this.falseNeg[trueClass]++;
        }
        // Every class that is neither the true class nor the predicted class is a true negative.
        for (int i = 0; i < this.numClasses; i++) {
            if (i != trueClass && i != predictedClass) {
                this.trueNeg[i]++;
            }
        }
    }

    /**
     * Returns all per-class measurements: first support, then precision, recall and F1-score.
     * The result is empty while the number of classes is still unknown.
     */
    @Override
    public Measurement[] getPerformanceMeasurements() {
        List<Measurement> measurements = new Vector<>(4 * trackedClasses());
        Collections.addAll(measurements, getSupportMeasurements());
        Collections.addAll(measurements, getPrecisionMeasurements());
        Collections.addAll(measurements, getRecallMeasurements());
        Collections.addAll(measurements, getF1Measurements());
        return measurements.toArray(new Measurement[measurements.size()]);
    }

    /**
     * Describes the most recently recorded prediction (classification only).
     *
     * <p>Entry 0 is the instance identifier, entry 1 the true class label and entry 2 the
     * predicted class label. They are followed by one {@code votes_<label>} entry per value of
     * the class attribute; a value with no corresponding vote is reported as {@code 0}.
     *
     * @return the prediction and its votes, or an empty array if no result has been recorded
     *         since construction or the last reset
     */
    @Override
    public Vote[] getPredictionVotes() {
        if (this.lastSeenInstance == null || this.classVotes == null) {
            return new Vote[0];
        }
        Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute();
        List<String> classAttributeValues = classAttribute.getAttributeValues();

        int trueNominalIndex = (int) this.lastSeenInstance.classValue();
        String trueNominalValue = classAttributeValues.get(trueNominalIndex);

        int numValues = classAttributeValues.size();
        Vote[] votes = new Vote[numValues + 3];
        votes[0] = new Vote("instance number", this.instanceIdentifier);
        votes[1] = new Vote("true class value", trueNominalValue);
        votes[2] = new Vote("predicted class value",
                classAttributeValues.get(Utils.maxIndex(this.classVotes)));

        // The per-class votes start after the three header entries above.
        for (int i = 0; i < numValues; i++) {
            double vote = i < this.classVotes.length ? this.classVotes[i] : 0.0;
            votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i), vote);
        }
        return votes;
    }

    /** Number of classes to report on; {@code 0} while the class count is still unknown. */
    private int trackedClasses() {
        return Math.max(this.numClasses, 0);
    }

    /** Builds one "class i support" measurement per tracked class. */
    private Measurement[] getSupportMeasurements() {
        Measurement[] measurements = new Measurement[trackedClasses()];
        for (int i = 0; i < measurements.length; i++) {
            String ml = String.format("class %s support", i);
            measurements[i] = new Measurement(ml, this.support[i]);
        }
        return measurements;
    }

    /** Builds one "class i precision" measurement per tracked class. */
    private Measurement[] getPrecisionMeasurements() {
        Measurement[] measurements = new Measurement[trackedClasses()];
        for (int i = 0; i < measurements.length; i++) {
            String ml = String.format("class %s precision", i);
            measurements[i] = new Measurement(ml, getPrecision(i), 10);
        }
        return measurements;
    }

    /** Builds one "class i recall" measurement per tracked class. */
    private Measurement[] getRecallMeasurements() {
        Measurement[] measurements = new Measurement[trackedClasses()];
        for (int i = 0; i < measurements.length; i++) {
            String ml = String.format("class %s recall", i);
            measurements[i] = new Measurement(ml, getRecall(i), 10);
        }
        return measurements;
    }

    /** Builds one "class i f1-score" measurement per tracked class. */
    private Measurement[] getF1Measurements() {
        Measurement[] measurements = new Measurement[trackedClasses()];
        for (int i = 0; i < measurements.length; i++) {
            String ml = String.format("class %s f1-score", i);
            measurements[i] = new Measurement(ml, getF1Score(i), 10);
        }
        return measurements;
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(getSupportMeasurements(), sb, indent);
        Measurement.getMeasurementsDescription(getPrecisionMeasurements(), sb, indent);
        Measurement.getMeasurementsDescription(getRecallMeasurements(), sb, indent);
        Measurement.getMeasurementsDescription(getF1Measurements(), sb, indent);
    }

    /** Returns {@code numerator / denominator}, or {@code 0} when the denominator is zero. */
    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / denominator;
    }

    /** Precision TP / (TP + FP) for the class; 0 if the class has never been predicted. */
    private double getPrecision(int classIndex) {
        return ratio(this.truePos[classIndex],
                this.truePos[classIndex] + this.falsePos[classIndex]);
    }

    /** Recall TP / (TP + FN) for the class; 0 if the class has never been the true label. */
    private double getRecall(int classIndex) {
        return ratio(this.truePos[classIndex],
                this.truePos[classIndex] + this.falseNeg[classIndex]);
    }

    /** Harmonic mean of precision and recall for the class; 0 when both are 0. */
    private double getF1Score(int classIndex) {
        double precision = getPrecision(classIndex);
        double recall = getRecall(classIndex);
        double sum = precision + recall;
        return sum == 0.0 ? 0.0 : 2 * precision * recall / sum;
    }

}
