/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.genetic;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import javax.cache.Cache.Entry;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.IgniteLogger;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.cache.query.SqlQuery;
import org.apache.ignite.ml.genetic.cache.GeneCacheConfig;
import org.apache.ignite.ml.genetic.cache.PopulationCacheConfig;
import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
import org.apache.ignite.ml.genetic.parameter.GAGridConstants;

/**
 * Central class responsible for orchestrating a distributed Genetic Algorithm.
 *
 * This class accepts a GAConfiguration and an Ignite instance.
 */
public class GAGrid {
    /** Ignite logger */
    private final IgniteLogger igniteLog;

    /** GAConfiguration */
    private final GAConfiguration cfg;

    /** Ignite instance */
    private final Ignite ignite;

    /** Population cache */
    private final IgniteCache<Long, Chromosome> populationCache;

    /** Gene cache */
    private final IgniteCache<Long, Gene> geneCache;

    /** Random number generator shared by all gene selections of this grid. */
    private final Random rnd = new Random();

    /**
     * Creates the grid and obtains (creating them if necessary) the population and gene caches.
     * Both caches are cleared.
     *
     * @param cfg GAConfiguration
     * @param ignite Ignite
     */
    public GAGrid(GAConfiguration cfg, Ignite ignite) {
        this.cfg = cfg;
        this.ignite = ignite;
        this.igniteLog = ignite.log();

        // Get/Create cache
        populationCache = this.ignite.getOrCreateCache(PopulationCacheConfig.populationCache());
        populationCache.clear();

        // Get/Create cache
        geneCache = this.ignite.getOrCreateCache(GeneCacheConfig.geneCache());
        geneCache.clear();
    }

    /**
     * Calculate average fitness value
     *
     * @return Average fitness score, or 0 when the query yields no (non-null) average
     */
    private Double calculateAverageFitness() {
        double avgFitnessScore = 0;

        IgniteCache<Long, Chromosome> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);

        // Execute query calculate average fitness
        SqlFieldsQuery sql = new SqlFieldsQuery("select AVG(FITNESSSCORE) from Chromosome");

        // Iterate over the result set.
        try (QueryCursor<List<?>> cursor = cache.query(sql)) {
            for (List<?> row : cursor) {
                Object avg = row.get(0);

                // A null aggregate (no rows to average) must not be unboxed.
                if (avg != null)
                    avgFitnessScore = (Double)avg;
            }
        }

        return avgFitnessScore;
    }

    /**
     * Calculate fitness each Chromosome in population
     *
     * @param chromosomeKeys List of chromosome primary keys
     */
    private void calculateFitness(List<Long> chromosomeKeys) {
        this.ignite.compute().execute(new FitnessTask(this.cfg), chromosomeKeys);
    }

    /**
     * Runs a {@code TruncateSelectionTask} that copies the fitter chromosomes over the selected ones.
     *
     * @param fittestKeys List of chromosome keys that will be copied from
     * @param selectedKeys List of chromosome keys that will be overwritten evenly by fittestKeys
     * @return Result of the task, or {@code false} when the truncate count is zero and nothing can be copied
     */
    private Boolean copyFitterChromosomesToPopulation(List<Long> fittestKeys, List<Long> selectedKeys) {
        double truncatePercentage = this.cfg.getTruncateRate();

        int totalSize = this.cfg.getPopulationSize();

        int truncateCnt = (int)(truncatePercentage * totalSize);

        // With a zero truncate count the division below would throw ArithmeticException.
        if (truncateCnt <= 0)
            return Boolean.FALSE;

        int numOfCopies = selectedKeys.size() / truncateCnt;

        return this.ignite.compute()
            .execute(new TruncateSelectionTask(fittestKeys, numOfCopies), selectedKeys);
    }

    /**
     * Create a Chromosome made of distinct genes.
     *
     * Genes are drawn repeatedly for each position until one is found that is not yet part of the chromosome.
     *
     * @param numOfGenes Number of Genes in respective Chromosome
     * @return Chromosome
     */
    private Chromosome createChromosome(int numOfGenes) {
        long[] genes = new long[numOfGenes];

        // Keys already placed in this chromosome; a set keeps the duplicate check constant-time.
        Set<Long> usedKeys = new HashSet<>();

        int k = 0;

        while (k < numOfGenes) {
            long key = selectGene(k);

            // add() returns false for a duplicate, in which case position k is retried.
            if (usedKeys.add(key)) {
                genes[k] = key;
                k += 1;
            }
        }

        return new Chromosome(genes);
    }

    /**
     * Perform crossover
     *
     * @param leastFitKeys List of primary keys for Chromosomes that are considered 'least fit'
     */
    private void crossover(List<Long> leastFitKeys) {
        this.ignite.compute().execute(new CrossOverTask(this.cfg), leastFitKeys);
    }

    /**
     * Evolve the population
     *
     * Initializes the gene pool and the population, then repeats selection, crossover, mutation and fitness
     * calculation until the configured terminate criteria reports that the termination condition is met.
     *
     * @return Fittest Chromosome
     */
    public Chromosome evolve() {
        // keep track of current generation
        int generationCnt = 1;

        Chromosome fittestChromosome;

        initializeGenePopulation();

        initializePopulation();

        // Calculate Fitness
        calculateFitness(getPopulationKeys());

        // Retrieve chromosomes in order by fitness value
        LinkedHashMap<Long, Double> map = getChromosomesByFittest();

        // Calculate average fitness value of population
        double averageFitnessScore = calculateAverageFitness();

        // The map is ordered fittest-first, so its first key identifies the fittest chromosome.
        Long key = map.keySet().iterator().next();

        fittestChromosome = populationCache.get(key);

        // while NOT terminateCondition met
        while (!(cfg.getTerminateCriteria().isTerminationConditionMet(fittestChromosome, averageFitnessScore,
            generationCnt))) {
            generationCnt += 1;

            // We will crossover/mutate over chromosomes based on selection method
            List<Long> selectedKeystoreCrossMutation = selection(map);

            // Cross Over
            crossover(selectedKeystoreCrossMutation);

            // Mutate
            mutation(selectedKeystoreCrossMutation);

            // Calculate Fitness
            calculateFitness(selectedKeystoreCrossMutation);

            // Retrieve chromosomes in order by fitness value
            map = getChromosomesByFittest();

            key = map.keySet().iterator().next();

            // Retrieve the first chromosome from the list
            fittestChromosome = populationCache.get(key);

            // Calculate average fitness value of population
            averageFitnessScore = calculateAverageFitness();
        }

        return fittestChromosome;
    }

    /**
     * helper routine to retrieve Chromosome keys in order of fittest
     *
     * The ordering is descending by fitness score, or ascending when the configuration states that a higher
     * fitness value is not fitter.
     *
     * @return Map of primary key/fitness score pairs for chromosomes.
     */
    private LinkedHashMap<Long, Double> getChromosomesByFittest() {
        LinkedHashMap<Long, Double> orderChromKeysByFittest = new LinkedHashMap<>();

        String orderDirection = cfg.isHigherFitnessValFitter() ? "desc" : "asc";

        String fittestSQL = "select _key, fitnessScore from Chromosome order by fitnessScore " + orderDirection;

        // Execute query to retrieve keys for ALL Chromosomes by fitnessScore; the cursor is closed afterwards.
        try (QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL))) {
            for (List<?> row : cursor.getAll()) {
                Long key = (Long)row.get(0);
                Double fitnessScore = (Double)row.get(1);

                orderChromKeysByFittest.put(key, fitnessScore);
            }
        }

        return orderChromKeysByFittest;
    }

    /**
     * Returns the leading part of the key list, sized by the configured truncate rate.
     *
     * @param keys List of primary keys for respective Chromosomes
     * @return List of keys for respective Chromosomes
     */
    private List<Long> getFittestKeysForTruncation(List<Long> keys) {
        double truncatePercentage = this.cfg.getTruncateRate();

        int truncateCnt = (int)(truncatePercentage * keys.size());

        return keys.subList(0, truncateCnt);
    }

    /**
     * initialize the Gene pool
     */
    void initializeGenePopulation() {
        geneCache.clear();

        List<Gene> genePool = cfg.getGenePool();

        for (Gene gene : genePool)
            geneCache.put(gene.id(), gene);
    }

    /**
     * Initialize the population of Chromosomes
     */
    void initializePopulation() {
        int populationSize = cfg.getPopulationSize();

        populationCache.clear();

        for (int j = 0; j < populationSize; j++) {
            Chromosome chromosome = createChromosome(cfg.getChromosomeLen());

            populationCache.put(chromosome.id(), chromosome);
        }
    }

    /**
     * Perform mutation
     *
     * @param leastFitKeys List of primary keys for Chromosomes that are considered 'least fit'.
     */
    private void mutation(List<Long> leastFitKeys) {
        this.ignite.compute().execute(new MutateTask(this.cfg), leastFitKeys);
    }

    /**
     * select a gene from the Gene pool
     *
     * @return Primary key of respective Gene
     */
    private long selectAnyGene() {
        List<Gene> genePool = cfg.getGenePool();

        return genePool.get(selectRandomIndex(genePool.size())).id();
    }

    /**
     * Elitism selection: the first {@code elitismCnt} keys of the (fittest-first) list are left untouched and
     * the remaining, less fit chromosomes are returned for crossover and mutation.
     *
     * @param keys List of primary keys for respective Chromosomes
     * @return List of primary Keys for respective Chromosomes that are considered least fit
     */
    private List<Long> selectByElitism(List<Long> keys) {
        int elitismCnt = this.cfg.getElitismCnt();

        return keys.subList(elitismCnt, keys.size());
    }

    /**
     * Truncation selection simply retains the fittest x% of the population. These fittest individuals are duplicated so
     * that the population size is maintained.
     *
     * @param keys Keys.
     * @return List of keys that follow the retained, fittest part of the list
     */
    private List<Long> selectByTruncation(List<Long> keys) {
        double truncatePercentage = this.cfg.getTruncateRate();

        int truncateCnt = (int)(truncatePercentage * keys.size());

        return keys.subList(truncateCnt, keys.size());
    }

    /**
     * Roulette Wheel selection
     *
     * @param map Map of keys/fitness scores
     * @return List of primary Keys for respective chromosomes that will breed
     */
    private List<Long> selectByRouletteWheel(LinkedHashMap<Long, Double> map) {
        return this.ignite.compute().execute(new RouletteWheelSelectionTask(this.cfg), map);
    }

    /**
     * Selects a gene for the given position, honouring the chromosome criteria when one is configured.
     *
     * @param k Gene index in Chromosome.
     * @return Primary key of respective Gene chosen
     */
    private long selectGene(int k) {
        if (cfg.getChromosomeCriteria() == null)
            return selectAnyGene();

        return selectGeneByChromosomeCriteria(k);
    }

    /**
     * method assumes ChromosomeCriteria is set.
     *
     * Queries the gene cache for genes whose value matches the k-th criteria and picks one of them at random.
     *
     * @param k Gene index in Chromosome
     * @return Primary key of respective Gene
     * @throws IllegalArgumentException if no gene matches the criteria
     */
    private long selectGeneByChromosomeCriteria(int k) {
        List<Gene> genes = new ArrayList<>();

        // The criteria is bound as a query argument so that quotes inside it cannot break the SQL clause.
        String likePattern = "%" + cfg.getChromosomeCriteria().getCriteria().get(k) + "%";

        IgniteCache<Long, Gene> cache = ignite.cache(GAGridConstants.GENE_CACHE);

        SqlQuery<Long, Gene> sql = new SqlQuery<>(Gene.class, "_val like ?");

        sql.setArgs(likePattern);

        try (QueryCursor<Entry<Long, Gene>> cursor = cache.query(sql)) {
            for (Entry<Long, Gene> e : cursor)
                genes.add(e.getValue());
        }

        int idx = selectRandomIndex(genes.size());

        return genes.get(idx).id();
    }

    /**
     * Picks a random index into a gene pool of the given size.
     *
     * @param sizeOfGenePool Size of Gene pool
     * @return Index in the range {@code [0, sizeOfGenePool)}
     * @throws IllegalArgumentException if the pool is empty
     */
    private int selectRandomIndex(int sizeOfGenePool) {
        if (sizeOfGenePool <= 0)
            throw new IllegalArgumentException("Cannot select a gene from an empty gene pool [size=" +
                sizeOfGenePool + ']');

        return rnd.nextInt(sizeOfGenePool);
    }

    /**
     * Select chromosomes
     *
     * @param map Map of keys/fitness scores for respective Chromosomes
     * @return List of primary keys for respective Chromosomes
     */
    private List<Long> selection(LinkedHashMap<Long, Double> map) {
        List<Long> selectedKeys = new ArrayList<>();

        // We will crossover/mutate over chromosomes based on selection method
        List<Long> chromosomeKeys = new ArrayList<>(map.keySet());

        GAGridConstants.SELECTION_METHOD selectionMtd = cfg.getSelectionMtd();

        switch (selectionMtd) {
            case SELECTION_METHOD_ELITISM:
                selectedKeys = selectByElitism(chromosomeKeys);
                break;

            case SELECTION_METHOD_TRUNCATION:
                selectedKeys = selectByTruncation(chromosomeKeys);

                List<Long> fittestKeys = getFittestKeysForTruncation(chromosomeKeys);

                // copy more fit keys to rest of population
                copyFitterChromosomesToPopulation(fittestKeys, selectedKeys);
                break;

            case SELECTION_METHOD_ROULETTE_WHEEL:
                selectedKeys = selectByRouletteWheel(map);
                break;

            default:
                break;
        }

        return selectedKeys;
    }

    /**
     * Get primary keys for Chromosomes
     *
     * @return List of Chromosome primary keys
     */
    List<Long> getPopulationKeys() {
        String keysSQL = "select _key from Chromosome";

        // Execute query to retrieve keys for ALL Chromosomes; the cursor is closed once the rows are read.
        try (QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(keysSQL))) {
            return cursor.getAll().stream()
                .map(row -> (Long)row.get(0))
                .collect(Collectors.toList());
        }
    }

}
