| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.ignite.ml.genetic; |
| |
| import java.util.ArrayList; |
| import java.util.LinkedHashMap; |
| import java.util.List; |
| import java.util.Random; |
| import java.util.stream.Collectors; |
| import javax.cache.Cache.Entry; |
| import org.apache.ignite.Ignite; |
| import org.apache.ignite.IgniteCache; |
| import org.apache.ignite.IgniteLogger; |
| import org.apache.ignite.cache.query.QueryCursor; |
| import org.apache.ignite.cache.query.SqlFieldsQuery; |
| import org.apache.ignite.cache.query.SqlQuery; |
| import org.apache.ignite.ml.genetic.cache.GeneCacheConfig; |
| import org.apache.ignite.ml.genetic.cache.PopulationCacheConfig; |
| import org.apache.ignite.ml.genetic.parameter.GAConfiguration; |
| import org.apache.ignite.ml.genetic.parameter.GAGridConstants; |
| |
| /** |
| * Central class responsible for orchestrating distributive Genetic Algorithm. |
| * |
| * This class accepts a GAConfigriation and Ignite instance. |
| */ |
| public class GAGrid { |
| /** Ignite logger */ |
| private IgniteLogger igniteLog; |
| |
| /** GAConfiguraton */ |
| private GAConfiguration cfg; |
| |
| /** Ignite instance */ |
| private Ignite ignite; |
| |
| /** Population cache */ |
| private IgniteCache<Long, Chromosome> populationCache; |
| |
| /** Gene cache */ |
| private IgniteCache<Long, Gene> geneCache; |
| |
| /** |
| * @param cfg GAConfiguration |
| * @param ignite Ignite |
| */ |
| public GAGrid(GAConfiguration cfg, Ignite ignite) { |
| this.ignite = ignite; |
| this.cfg = cfg; |
| this.ignite = ignite; |
| this.igniteLog = ignite.log(); |
| |
| // Get/Create cache |
| populationCache = this.ignite.getOrCreateCache(PopulationCacheConfig.populationCache()); |
| populationCache.clear(); |
| |
| // Get/Create cache |
| geneCache = this.ignite.getOrCreateCache(GeneCacheConfig.geneCache()); |
| geneCache.clear(); |
| } |
| |
| /** |
| * Calculate average fitness value |
| * |
| * @return Average fitness score |
| */ |
| private Double calculateAverageFitness() { |
| double avgFitnessScore = 0; |
| |
| IgniteCache<Long, Gene> cache = ignite.cache(GAGridConstants.POPULATION_CACHE); |
| |
| // Execute query calculate average fitness |
| SqlFieldsQuery sql = new SqlFieldsQuery("select AVG(FITNESSSCORE) from Chromosome"); |
| |
| // Iterate over the result set. |
| try (QueryCursor<List<?>> cursor = cache.query(sql)) { |
| for (List<?> row : cursor) |
| avgFitnessScore = (Double)row.get(0); |
| } |
| |
| return avgFitnessScore; |
| } |
| |
| /** |
| * Calculate fitness each Chromosome in population |
| * |
| * @param chromosomeKeys List of chromosome primary keys |
| */ |
| private void calculateFitness(List<Long> chromosomeKeys) { |
| this.ignite.compute().execute(new FitnessTask(this.cfg), chromosomeKeys); |
| } |
| |
| /** |
| * @param fittestKeys List of chromosome keys that will be copied from |
| * @param selectedKeys List of chromosome keys that will be overwritten evenly by fittestKeys |
| * @return Boolean value |
| */ |
| private Boolean copyFitterChromosomesToPopulation(List<Long> fittestKeys, List<Long> selectedKeys) { |
| double truncatePercentage = this.cfg.getTruncateRate(); |
| |
| int totalSize = this.cfg.getPopulationSize(); |
| |
| int truncateCnt = (int)(truncatePercentage * totalSize); |
| |
| int numOfCopies = selectedKeys.size() / truncateCnt; |
| |
| return this.ignite.compute() |
| .execute(new TruncateSelectionTask(fittestKeys, numOfCopies), selectedKeys); |
| } |
| |
| /** |
| * create a Chromsome |
| * |
| * @param numOfGenes Number of Genes in resepective Chromosome |
| * @return Chromosome |
| */ |
| private Chromosome createChromosome(int numOfGenes) { |
| long[] genes = new long[numOfGenes]; |
| List<Long> keys = new ArrayList<>(); |
| int k = 0; |
| while (k < numOfGenes) { |
| long key = selectGene(k); |
| |
| if (!(keys.contains(key))) { |
| genes[k] = key; |
| keys.add(key); |
| k += 1; |
| } |
| } |
| return new Chromosome(genes); |
| } |
| |
| /** |
| * Perform crossover |
| * |
| * @param leastFitKeys List of primary keys for Chromosomes that are considered 'least fit' |
| */ |
| private void crossover(List<Long> leastFitKeys) { |
| this.ignite.compute().execute(new CrossOverTask(this.cfg), leastFitKeys); |
| } |
| |
| /** |
| * Evolve the population |
| * |
| * @return Fittest Chromosome |
| */ |
| public Chromosome evolve() { |
| // keep track of current generation |
| int generationCnt = 1; |
| |
| Chromosome fittestChromosome; |
| |
| initializeGenePopulation(); |
| |
| initializePopulation(); |
| |
| // Calculate Fitness |
| calculateFitness(getPopulationKeys()); |
| |
| // Retrieve chromosomes in order by fitness value |
| LinkedHashMap<Long, Double> map = getChromosomesByFittest(); |
| |
| // Calculate average fitness value of population |
| double averageFitnessScore = calculateAverageFitness(); |
| |
| Long key = map.keySet().iterator().next(); |
| |
| fittestChromosome = populationCache.get(key); |
| |
| // while NOT terminateCondition met |
| while (!(cfg.getTerminateCriteria().isTerminationConditionMet(fittestChromosome, averageFitnessScore, |
| generationCnt))) { |
| generationCnt += 1; |
| |
| // We will crossover/mutate over chromosomes based on selection method |
| List<Long> selectedKeystoreCrossMutation = selection(map); |
| |
| // Cross Over |
| crossover(selectedKeystoreCrossMutation); |
| |
| // Mutate |
| mutation(selectedKeystoreCrossMutation); |
| |
| // Calculate Fitness |
| calculateFitness(selectedKeystoreCrossMutation); |
| |
| // Retrieve chromosomes in order by fitness value |
| map = getChromosomesByFittest(); |
| |
| key = map.keySet().iterator().next(); |
| |
| // Retreive the first chromosome from the list |
| fittestChromosome = populationCache.get(key); |
| |
| // Calculate average fitness value of population |
| averageFitnessScore = calculateAverageFitness(); |
| |
| // End Loop |
| |
| } |
| return fittestChromosome; |
| } |
| |
| /** |
| * helper routine to retrieve Chromosome keys in order of fittest |
| * |
| * @return Map of primary key/fitness score pairs for chromosomes. |
| */ |
| private LinkedHashMap<Long,Double> getChromosomesByFittest() { |
| LinkedHashMap<Long, Double> orderChromKeysByFittest = new LinkedHashMap<>(); |
| |
| String orderDirection = "desc"; |
| |
| if (!cfg.isHigherFitnessValFitter()) |
| orderDirection = "asc"; |
| |
| String fittestSQL = "select _key, fitnessScore from Chromosome order by fitnessScore " + orderDirection; |
| |
| // Execute query to retrieve keys for ALL Chromosomes by fittnessScore |
| QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL)); |
| |
| List<List<?>> res = cursor.getAll(); |
| |
| for (List row : res) { |
| Long key = (Long)row.get(0); |
| Double fitnessScore = (Double)row.get(1); |
| orderChromKeysByFittest.put(key, fitnessScore); |
| } |
| |
| return orderChromKeysByFittest; |
| } |
| |
| /** |
| * @param keys List of primary keys for respective Chromosomes |
| * @return List of keys for respective Chromosomes |
| */ |
| private List<Long> getFittestKeysForTruncation(List<Long> keys) { |
| double truncatePercentage = this.cfg.getTruncateRate(); |
| |
| int truncateCnt = (int)(truncatePercentage * keys.size()); |
| |
| return keys.subList(0, truncateCnt); |
| } |
| |
| /** |
| * initialize the Gene pool |
| */ |
| void initializeGenePopulation() { |
| geneCache.clear(); |
| |
| List<Gene> genePool = cfg.getGenePool(); |
| |
| for (Gene gene : genePool) |
| geneCache.put(gene.id(), gene); |
| } |
| |
| /** |
| * Initialize the population of Chromosomes |
| */ |
| void initializePopulation() { |
| int populationSize = cfg.getPopulationSize(); |
| populationCache.clear(); |
| |
| for (int j = 0; j < populationSize; j++) { |
| Chromosome chromosome = createChromosome(cfg.getChromosomeLen()); |
| populationCache.put(chromosome.id(), chromosome); |
| } |
| |
| } |
| |
| /** |
| * Perform mutation |
| * |
| * @param leastFitKeys List of primary keys for Chromosomes that are considered 'least fit'. |
| */ |
| private void mutation(List<Long> leastFitKeys) { |
| this.ignite.compute().execute(new MutateTask(this.cfg), leastFitKeys); |
| } |
| |
| /** |
| * select a gene from the Gene pool |
| * |
| * @return Primary key of respective Gene |
| */ |
| private long selectAnyGene() { |
| int idx = selectRandomIndex(cfg.getGenePool().size()); |
| Gene gene = cfg.getGenePool().get(idx); |
| return gene.id(); |
| } |
| |
| /** |
| * For our implementation we consider 'best fit' chromosomes, by selecting least fit chromosomes for crossover and |
| * mutation |
| * |
| * As result, we are interested in least fit chromosomes. |
| * |
| * @param keys List of primary keys for respective Chromosomes |
| * @return List of primary Keys for respective Chromosomes that are considered least fit |
| */ |
| private List<Long> selectByElitism(List<Long> keys) { |
| int elitismCnt = this.cfg.getElitismCnt(); |
| return keys.subList(elitismCnt, keys.size()); |
| } |
| |
| /** |
| * Truncation selection simply retains the fittest x% of the population. These fittest individuals are duplicated so |
| * that the population size is maintained. |
| * |
| * @param keys Keys. |
| * @return List of keys |
| */ |
| private List<Long> selectByTruncation(List<Long> keys) { |
| double truncatePercentage = this.cfg.getTruncateRate(); |
| |
| int truncateCnt = (int)(truncatePercentage * keys.size()); |
| |
| return keys.subList(truncateCnt, keys.size()); |
| } |
| |
| /** |
| * Roulette Wheel selection |
| * |
| * @param map Map of keys/fitness scores |
| * @return List of primary Keys for respective chromosomes that will breed |
| */ |
| private List<Long> selectByRouletteWheel(LinkedHashMap map) { |
| List<Long> populationKeys = this.ignite.compute().execute(new RouletteWheelSelectionTask(this.cfg), map); |
| |
| return populationKeys; |
| } |
| |
| /** |
| * @param k Gene index in Chromosome. |
| * @return Primary key of respective Gene chosen |
| */ |
| private long selectGene(int k) { |
| if (cfg.getChromosomeCriteria() == null) |
| return (selectAnyGene()); |
| else |
| return (selectGeneByChromosomeCriteria(k)); |
| } |
| |
| /** |
| * method assumes ChromosomeCriteria is set. |
| * |
| * @param k Gene index in Chromosome |
| * @return Primary key of respective Gene |
| */ |
| private long selectGeneByChromosomeCriteria(int k) { |
| List<Gene> genes = new ArrayList<>(); |
| |
| StringBuffer sbSqlClause = new StringBuffer("_val like '"); |
| sbSqlClause.append("%"); |
| sbSqlClause.append(cfg.getChromosomeCriteria().getCriteria().get(k)); |
| sbSqlClause.append("%'"); |
| |
| IgniteCache<Long, Gene> cache = ignite.cache(GAGridConstants.GENE_CACHE); |
| |
| SqlQuery sql = new SqlQuery(Gene.class, sbSqlClause.toString()); |
| |
| try (QueryCursor<Entry<Long, Gene>> cursor = cache.query(sql)) { |
| for (Entry<Long, Gene> e : cursor) |
| genes.add(e.getValue()); |
| } |
| |
| int idx = selectRandomIndex(genes.size()); |
| |
| Gene gene = genes.get(idx); |
| return gene.id(); |
| } |
| |
| /** |
| * @param sizeOfGenePool Size of Gene pool |
| * @return Index |
| */ |
| private int selectRandomIndex(int sizeOfGenePool) { |
| Random randomGenerator = new Random(); |
| return randomGenerator.nextInt(sizeOfGenePool); |
| } |
| |
| /** |
| * Select chromosomes |
| * |
| * @param map Map of keys/fitness scores for respective Chromosomes |
| * @return List of primary keys for respective Chromosomes |
| */ |
| private List<Long> selection(LinkedHashMap map) { |
| List<Long> selectedKeys = new ArrayList<>(); |
| |
| // We will crossover/mutate over chromosomes based on selection method |
| List<Long> chromosomeKeys = new ArrayList<>(map.keySet()); |
| |
| GAGridConstants.SELECTION_METHOD selectionMtd = cfg.getSelectionMtd(); |
| |
| switch (selectionMtd) { |
| case SELECTION_METHOD_ELITISM: |
| selectedKeys = selectByElitism(chromosomeKeys); |
| break; |
| case SELECTION_METHOD_TRUNCATION: |
| selectedKeys = selectByTruncation(chromosomeKeys); |
| |
| List<Long> fittestKeys = getFittestKeysForTruncation(chromosomeKeys); |
| |
| copyFitterChromosomesToPopulation(fittestKeys, selectedKeys); |
| |
| // copy more fit keys to rest of population |
| break; |
| case SELECTION_METHOD_ROULETTE_WHEEL: |
| selectedKeys = this.selectByRouletteWheel(map); |
| |
| default: |
| break; |
| } |
| |
| return selectedKeys; |
| } |
| |
| /** |
| * Get primary keys for Chromosomes |
| * |
| * @return List of Chromosome primary keys |
| */ |
| List<Long> getPopulationKeys() { |
| String fittestSQL = "select _key from Chromosome"; |
| |
| // Execute query to retrieve keys for ALL Chromosomes |
| QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL)); |
| |
| List<List<?>> res = cursor.getAll(); |
| |
| return (List<Long>) res.stream().map(x -> x.get(0)).collect(Collectors.toList()); |
| } |
| |
| } |