# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from ast import literal_eval
from datetime import datetime
from hyperopt import hp, rand, tpe, atpe, Trials, STATUS_OK, STATUS_RUNNING
from hyperopt.base import Domain
import math
import numpy as np
import plpy
import time

from madlib_keras_validator import MstLoaderInputValidator
# from utilities.admin import cleanup_madlib_temp_tables
from utilities.utilities import get_current_timestamp, get_seg_number, get_segments_per_host, \
    unique_string, add_postfix, extract_keyvalue_params, _assert, _assert_equal, rename_table
from utilities.control import SetGUC
from madlib_keras_fit_multiple_model import FitMultipleModel
from madlib_keras_helper import generate_row_string
from madlib_keras_helper import DISTRIBUTION_RULES
from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
from keras_model_arch_table import ModelArchSchema
from utilities.validate_args import table_exists, drop_tables, input_tbl_valid
from utilities.validate_args import quote_ident

class AutoMLConstants:
    """Names and defaults shared by the AutoML code in this module.

    Groups: hyperband schedule table column names (BRACKET .. RESOURCES),
    automl method names, automl_params keys (R, ETA, SKIP_LAST), the info-table
    column used to rank configurations (LOSS_METRIC), temporary table names,
    and miscellaneous formatting/limit values.

    NOTE(review): LOSS_METRIC is mutable shared state -- KerasAutoML.__init__
    reassigns it depending on whether a validation table was supplied, and
    the new value stays on the class for later callers in the same session.
    """
    # Column names of the hyperband schedule table (bracket s, round i,
    # number of configurations n_i, resources r_i).
    BRACKET = 's'
    ROUND = 'i'
    CONFIGURATIONS = 'n_i'
    RESOURCES = 'r_i'
    # Supported automl_method values (compared case-insensitively by
    # KerasAutoML.is_automl_method).
    HYPERBAND = 'hyperband'
    HYPEROPT = 'hyperopt'
    # Keys accepted in the hyperband automl_params string.
    R = 'R'
    ETA = 'eta'
    SKIP_LAST = 'skip_last'
    # Info-table column used in ORDER BY to pick/prune the best configurations.
    LOSS_METRIC = 'training_loss_final'
    # Temp table names are evaluated once, when this class body runs at module
    # import, so every call in the same interpreter session reuses them.
    # (unique_string presumably yields a randomized name -- confirm in utilities.)
    TEMP_MST_TABLE = unique_string('temp_mst_table')
    TEMP_MST_SUMMARY_TABLE = add_postfix(TEMP_MST_TABLE, '_summary')
    TEMP_OUTPUT_TABLE = unique_string('temp_output_table')
    METRICS_ITERS = 'metrics_iters' # custom column
    # Not referenced in the code visible here; presumably hyperopt param keys.
    NUM_CONFIGS = 'num_configs'
    NUM_ITERS = 'num_iterations'
    ALGORITHM = 'algorithm'
    # strftime-style format passed to get_current_timestamp.
    TIME_FORMAT = '%Y-%m-%d %H:%M:%S'
    # Largest signed 32-bit integer.
    INT_MAX = 2 ** 31 - 1
    # Only referenced by the commented-out temp-table cleanup call.
    TARGET_SCHEMA = 'public'

class HyperbandSchedule(object):
    """The utility class for loading a hyperband schedule table with algorithm inputs.

    Attributes:
        schedule_table (string): Name of output table containing hyperband schedule.
        R (int): Maximum number of resources (iterations) that can be allocated
            to a single configuration.
        eta (int): Controls the proportion of configurations discarded in
            each round of successive halving.
        skip_last (int): The number of last rounds to skip.
    """
    def __init__(self, schedule_table, R, eta=3, skip_last=0):
        self.schedule_table = schedule_table  # table name to store hyperband schedule
        self.R = R  # maximum iterations/epochs allocated to a configuration
        self.eta = eta  # defines downsampling rate
        self.skip_last = skip_last
        self.module_name = 'hyperband_schedule'
        self.validate_inputs()

        # number of unique executions of Successive Halving (minus one),
        # i.e. floor(log_eta(R)) computed exactly
        self.s_max = self._compute_s_max(self.R, self.eta)
        self.validate_s_max()

        self.schedule_vals = []

        self.calculate_schedule()

    @staticmethod
    def _compute_s_max(R, eta):
        """
        Return floor(log_eta(R)) using exact arithmetic.

        math.log(R, eta) is not exact at powers of eta (for example
        math.log(1000, 10) == 2.9999999999999996), which would drop a whole
        bracket after flooring, so repeated multiplication is used instead.
        :param R: maximum resources per configuration.
        :param eta: downsampling rate, must be > 1.
        :return: int s_max.
        :raises ValueError: if eta <= 1 (the loop would never terminate).
        """
        if eta <= 1:
            raise ValueError("eta must be greater than 1")
        s_max = 0
        while eta ** (s_max + 1) <= R:
            s_max += 1
        return s_max

    @staticmethod
    def _schedule_rows(R, eta, s_max, skip_last):
        """
        Compute the hyperband schedule as a list of (s, i, n_i, r_i) tuples.

        Brackets are produced from s_max down to 0; within bracket s, rounds
        i = 0 .. s - skip_last are produced (brackets with s < skip_last
        therefore contribute no rows).
        :return: list of (bracket, round, num_configs, resources) tuples.
        """
        rows = []
        for s in reversed(range(s_max + 1)):
            # initial number of configurations; the ratio (s_max+1)/(s+1) is
            # truncated to an integer, matching AutoMLHyperband's computation
            n = int(math.ceil(((s_max + 1) // (s + 1)) * math.pow(eta, s)))
            r = R * math.pow(eta, -s)

            for i in range((s + 1) - int(skip_last)):
                # n is a multiple of eta**s and i <= s, so floor division is
                # exact and avoids float truncation of n * eta**-i
                n_i = n // eta ** i
                r_i = r * math.pow(eta, i)
                rows.append((s, i, int(n_i), int(round(r_i))))
        return rows

    def load(self):
        """
        The entry point for loading the hyperband schedule table.
        """
        self.create_schedule_table()
        self.insert_into_schedule_table()

    def validate_inputs(self):
        """
        Validates user input values
        """
        _assert(self.eta > 1, "{0}: eta must be greater than 1".format(self.module_name))
        _assert(self.R >= self.eta, "{0}: R should not be less than eta".format(self.module_name))

    def validate_s_max(self):
        """Validates skip_last against the computed s_max (0 <= skip_last < s_max + 1)."""
        _assert(self.skip_last >= 0 and self.skip_last < self.s_max + 1,
                "{0}: skip_last must be non-negative and less than {1}".format(
                    self.module_name, self.s_max + 1))

    def calculate_schedule(self):
        """
        Calculates the hyperband schedule (number of configs and allocated resources)
        in each round of each bracket and skips the number of last rounds specified in 'skip_last'
        """
        for s, i, n_i, r_i in self._schedule_rows(self.R, self.eta, self.s_max, self.skip_last):
            self.schedule_vals.append({AutoMLConstants.BRACKET: s,
                                       AutoMLConstants.ROUND: i,
                                       AutoMLConstants.CONFIGURATIONS: n_i,
                                       AutoMLConstants.RESOURCES: r_i})

    def create_schedule_table(self):
        """Initializes the output schedule table"""
        create_query = """
                        CREATE TABLE {self.schedule_table} (
                            {s} INTEGER,
                            {i} INTEGER,
                            {n_i} INTEGER,
                            {r_i} INTEGER,
                            unique ({s}, {i})
                        );
                       """.format(self=self,
                                  s=AutoMLConstants.BRACKET,
                                  i=AutoMLConstants.ROUND,
                                  n_i=AutoMLConstants.CONFIGURATIONS,
                                  r_i=AutoMLConstants.RESOURCES)
        plpy.execute(create_query)

    def insert_into_schedule_table(self):
        """Insert everything in self.schedule_vals into the output schedule table
        with a single multi-row INSERT."""
        if not self.schedule_vals:
            return
        values = ",\n".join(
            "({0}, {1}, {2}, {3})".format(sd[AutoMLConstants.BRACKET],
                                          sd[AutoMLConstants.ROUND],
                                          sd[AutoMLConstants.CONFIGURATIONS],
                                          sd[AutoMLConstants.RESOURCES])
            for sd in self.schedule_vals)
        insert_query = """
                        INSERT INTO {table} ({s_col}, {i_col}, {n_i_col}, {r_i_col})
                        VALUES {values}
                       """.format(table=self.schedule_table,
                                  s_col=AutoMLConstants.BRACKET,
                                  i_col=AutoMLConstants.ROUND,
                                  n_i_col=AutoMLConstants.CONFIGURATIONS,
                                  r_i_col=AutoMLConstants.RESOURCES,
                                  values=values)
        plpy.execute(insert_query)

class KerasAutoML(object):
    """
    The core AutoML class for running AutoML algorithms such as Hyperband and Hyperopt.

    Subclasses are expected to set ``start_training_time`` and
    ``end_training_time`` before calling generate_model_output_summary_table().
    """
    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
                 automl_params=None, random_state=None, object_table=None,
                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
                 name=None, description=None, **kwargs):
        self.schema_madlib = schema_madlib
        self.source_table = source_table
        self.model_output_table = model_output_table
        self.module_name = 'madlib_keras_automl'
        input_tbl_valid(self.source_table, self.module_name)
        # Always define the derived table names so the attributes exist even
        # when no output table was given.
        self.model_info_table = None
        self.model_summary_table = None
        if self.model_output_table:
            self.model_info_table = add_postfix(self.model_output_table, '_info')
            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(list(set(model_id_list)))
        self.compile_params_grid = compile_params_grid
        self.fit_params_grid = fit_params_grid

        if object_table is not None:
            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))

        MstLoaderInputValidator(
            schema_madlib=self.schema_madlib,
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_grid,
            fit_params_list=fit_params_grid,
            object_table=object_table,
            module_name='madlib_keras_automl'
        )

        self.automl_method = automl_method
        self.automl_params = automl_params
        self.random_state = random_state

        self.object_table = object_table
        self.use_gpus = use_gpus if use_gpus else False
        self.validation_table = validation_table
        self.metrics_compute_frequency = metrics_compute_frequency
        self.name = name
        self.description = description

        # LOSS_METRIC is a class attribute shared by every run in this
        # interpreter session; assign it in both cases so a previous run with
        # a validation table does not leak into a run without one.
        AutoMLConstants.LOSS_METRIC = self._select_loss_metric(self.validation_table)

    @staticmethod
    def _select_loss_metric(validation_table):
        """
        Return the info-table column used to rank configurations.
        :param validation_table: validation table name, or None/empty.
        :return: 'validation_loss_final' if a validation table is given,
                 otherwise 'training_loss_final'.
        """
        return 'validation_loss_final' if validation_table else 'training_loss_final'

    def create_model_output_table(self):
        """Creates the model output table (mst_key, model weights, model arch)."""
        output_table_create_query = """
                                    CREATE TABLE {self.model_output_table}
                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
                                     {ModelArchSchema.MODEL_ARCH} JSON)
                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
                                               ModelArchSchema=ModelArchSchema)
        plpy.execute(output_table_create_query)

    def create_model_output_info_table(self):
        """Creates the per-configuration model info table (params, losses, metrics)."""
        info_table_create_query = """
                                  CREATE TABLE {self.model_info_table}
                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
                                   {ModelArchSchema.MODEL_ID} INTEGER,
                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
                                   model_type TEXT,
                                   model_size DOUBLE PRECISION,
                                   metrics_elapsed_time DOUBLE PRECISION[],
                                   metrics_type TEXT[],
                                   loss_type TEXT,
                                   training_metrics_final DOUBLE PRECISION,
                                   training_loss_final DOUBLE PRECISION,
                                   training_metrics DOUBLE PRECISION[],
                                   training_loss DOUBLE PRECISION[],
                                   validation_metrics_final DOUBLE PRECISION,
                                   validation_loss_final DOUBLE PRECISION,
                                   validation_metrics DOUBLE PRECISION[],
                                   validation_loss DOUBLE PRECISION[],
                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
                                   """.format(self=self,
                                              ModelSelectionSchema=ModelSelectionSchema,
                                              ModelArchSchema=ModelArchSchema,
                                              AutoMLSchema=AutoMLConstants)
        plpy.execute(info_table_create_query)

    def update_model_selection_table(self):
        """
        Drops and re-create the mst table to only include the best performing model configuration.
        """
        drop_tables([self.model_selection_table])

        # only retaining best performing config (lowest value of LOSS_METRIC)
        plpy.execute("CREATE TABLE {self.model_selection_table} AS SELECT {ModelSelectionSchema.MST_KEY}, " \
                     "{ModelSelectionSchema.MODEL_ID}, {ModelSelectionSchema.COMPILE_PARAMS}, " \
                     "{ModelSelectionSchema.FIT_PARAMS} FROM {self.model_info_table} " \
                     "ORDER BY {AutoMLSchema.LOSS_METRIC} LIMIT 1".format(self=self,
                                                                          AutoMLSchema=AutoMLConstants,
                                                                          ModelSelectionSchema=ModelSelectionSchema))

    def generate_model_output_summary_table(self, model_training):
        """
        Creates and populates static values related to the AutoML workload.
        Reads self.start_training_time / self.end_training_time, which must
        have been set by the caller.
        :param model_training: Fit Multiple function call object.
        """
        #TODO this code is duplicated in create_model_summary_table
        name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
        descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
        object_table = 'NULL' if self.object_table is None else '$MAD${0}$MAD$'.format(self.object_table)
        random_state = 'NULL' if self.random_state is None else '$MAD${0}$MAD$'.format(self.random_state)
        validation_table = 'NULL' if self.validation_table is None else '$MAD${0}$MAD$'.format(self.validation_table)

        create_query = plpy.prepare("""
                CREATE TABLE {self.model_summary_table} AS
                SELECT
                    $MAD${self.source_table}$MAD$::TEXT AS source_table,
                    {validation_table}::TEXT AS validation_table,
                    $MAD${self.model_output_table}$MAD$::TEXT AS model,
                    $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
                    (SELECT dependent_varname FROM {model_training.model_summary_table})
                    AS dependent_varname,
                    (SELECT independent_varname FROM {model_training.model_summary_table})
                    AS independent_varname,
                    $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
                    $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
                    $MAD${self.automl_method}$MAD$::TEXT AS automl_method,
                    $MAD${self.automl_params}$MAD$::TEXT AS automl_params,
                    {random_state}::TEXT AS random_state,
                    {object_table}::TEXT AS object_table,
                    {self.use_gpus} AS use_gpus,
                    (SELECT metrics_compute_frequency FROM {model_training.model_summary_table})::INTEGER
                    AS metrics_compute_frequency,
                    {name}::TEXT AS name,
                    {descr}::TEXT AS description,
                    '{self.start_training_time}'::TIMESTAMP AS start_training_time,
                    '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                    (SELECT madlib_version FROM {model_training.model_summary_table}) AS madlib_version,
                    (SELECT num_classes FROM {model_training.model_summary_table})::INTEGER AS num_classes,
                    (SELECT class_values FROM {model_training.model_summary_table}) AS class_values,
                    (SELECT dependent_vartype FROM {model_training.model_summary_table})
                    AS dependent_vartype,
                    (SELECT normalizing_const FROM {model_training.model_summary_table})
                    AS normalizing_const
            """.format(self=self,
                       validation_table=validation_table,
                       random_state=random_state,
                       object_table=object_table,
                       name=name,
                       descr=descr,
                       model_training=model_training))

        plpy.execute(create_query)

    def is_automl_method(self, method_name):
        """
        Utility function to check automl method name (case-insensitive).
        :param method_name: name of chosen method name to check.
        :return: boolean
        """
        return self.automl_method.lower() == method_name.lower()

    def _is_valid_metrics_compute_frequency(self, num_iterations):
        """
        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
        metrics during an AutoML algorithm run.
        :param num_iterations: iterations/resources to allocate for training.
        :return: True if mcf is None or within [1, num_iterations].
        """
        return self.metrics_compute_frequency is None or \
               (self.metrics_compute_frequency >= 1 and \
                self.metrics_compute_frequency <= num_iterations)

    def print_best_mst_so_far(self):
        """
        Prints mst keys with best train/val losses at a given point.
        """
        best_so_far = '\n'
        best_so_far += self.print_best_helper('training')
        if self.validation_table:
            best_so_far += self.print_best_helper('validation')
        plpy.info(best_so_far)

    def print_best_helper(self, keyword):
        """
        Helper function to format the mst key with the best train/val loss at a given point.
        :param keyword: column prefix ('training' or 'validation')
        :return: formatted multi-line string.
        """
        metrics_word, loss_word = keyword + '_metrics_final', keyword + '_loss_final'

        res_str = 'Best {keyword} loss so far:\n'.format(keyword=keyword)
        rows = plpy.execute("SELECT {ModelSelectionSchema.MST_KEY}, {metrics_word}, " \
                            "{loss_word} FROM {self.model_info_table} ORDER BY " \
                            "{loss_word} LIMIT 1".format(self=self, ModelSelectionSchema=ModelSelectionSchema,
                                                         metrics_word=metrics_word, loss_word=loss_word))
        if not rows:
            # nothing trained yet; avoid an IndexError on the empty result
            return res_str + 'No results available yet.\n'
        best_value = rows[0]
        mst_key_value, metric_value, loss_value = best_value[ModelSelectionSchema.MST_KEY], \
                                                  best_value[metrics_word], best_value[loss_word]
        res_str += ModelSelectionSchema.MST_KEY + '=' + str(mst_key_value) + ': metric=' + str(metric_value) + \
                   ', loss=' + str(loss_value) + '\n'
        return res_str

    def remove_temp_tables(self, model_training):
        """
        Remove all intermediate tables created for AutoML runs/updates.
        :param model_training: Fit Multiple function call object.
        """
        drop_tables([model_training.original_model_output_table, model_training.model_info_table,
                     model_training.model_summary_table, AutoMLConstants.TEMP_MST_TABLE,
                     AutoMLConstants.TEMP_MST_SUMMARY_TABLE])

class AutoMLHyperband(KerasAutoML):
    """
    This class implements Hyperband, an infinite-arm bandit based algorithm that speeds up random search
    through adaptive resource allocation, successive halving (SHA), and early stopping.

    This class showcases a novel hyperband implementation by executing the hyperband rounds 'diagonally'
    to evaluate multiple configurations together and leverage the compute power of MPP databases such as Greenplum.

    This automl method inherits qualities from the automl class.
    """
    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
                 model_id_list, compile_params_grid, fit_params_grid, automl_method,
                 automl_params, random_state=None, object_table=None,
                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
                 name=None, description=None, **kwargs):
        """
        Sets up and immediately runs a hyperband search.

        Empty/NULL automl_method and automl_params fall back to 'hyperband' and
        'R=6, eta=3, skip_last=0'. After the base-class setup the hyperband
        params are validated, the output and info tables are created, and the
        diagonal hyperband run is executed.
        """
        # `x or default` picks the default for any falsy input (None, '').
        method = automl_method or AutoMLConstants.HYPERBAND
        params = automl_params or 'R=6, eta=3, skip_last=0'
        KerasAutoML.__init__(self, schema_madlib, source_table, model_output_table,
                             model_arch_table, model_selection_table, model_id_list,
                             compile_params_grid, fit_params_grid, method, params,
                             random_state=random_state,
                             object_table=object_table,
                             use_gpus=use_gpus,
                             validation_table=validation_table,
                             metrics_compute_frequency=metrics_compute_frequency,
                             name=name,
                             description=description,
                             **kwargs)

        # validate params -> create output tables -> run the search
        self.validate_and_define_inputs()
        self.create_model_output_table()
        self.create_model_output_info_table()
        self.find_hyperband_config()

    def validate_and_define_inputs(self):
        """
        Parses and validates the hyperband automl_params string.

        Sets self.R, self.eta, self.skip_last and self.s_max. Any of R, eta,
        skip_last that is not specified falls back to the same default used
        when automl_params is omitted entirely (R=6, eta=3, skip_last=0).
        Raises (via plpy.error / _assert) on unknown or non-integer params and
        on out-of-range values.
        """
        automl_params_dict = extract_keyvalue_params(self.automl_params,
                                                     lower_case_names=False)
        _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
                "{0}: Only R, eta, and skip_last may be specified".format(self.module_name))

        # defaults for params the user did not specify
        self.R, self.eta, self.skip_last = 6, 3, 0
        for param_name, raw_value in automl_params_dict.items():
            try:
                value = int(raw_value)
            except (TypeError, ValueError):
                plpy.error("{0}: automl param {1} must be an integer".format(self.module_name, param_name))
            if param_name == AutoMLConstants.R:
                self.R = value
            elif param_name == AutoMLConstants.ETA:
                self.eta = value
            elif param_name == AutoMLConstants.SKIP_LAST:
                self.skip_last = value
            else:
                plpy.error("{0}: {1} is an invalid automl param".format(self.module_name, param_name))

        _assert(self.eta > 1, "{0}: eta must be greater than 1".format(self.module_name))
        _assert(self.R >= self.eta, "{0}: R should not be less than eta".format(self.module_name))

        # s_max = floor(log_eta(R)) computed exactly: math.log(R, eta) can
        # under-shoot at exact powers (e.g. math.log(1000, 10) < 3).
        # Terminates because eta > 1 was asserted above.
        s_max = 0
        while self.eta ** (s_max + 1) <= self.R:
            s_max += 1
        self.s_max = s_max

        _assert(self.skip_last >= 0 and self.skip_last < self.s_max + 1,
                "{0}: skip_last must be non-negative and less than {1}".format(
                    self.module_name, self.s_max + 1))

    def find_hyperband_config(self):
        """
        Executes the diagonal hyperband algorithm.

        All configurations for the brackets that will be trained are sampled
        up front with a random MstSearch. Outer iteration i then trains, with
        the same number of iterations, round 0 of bracket s_max - i together
        with the pruned survivors of brackets s_max .. s_max - i + 1. Finally
        the best configuration is kept in the model selection table, the
        summary table is written and temp tables are dropped.
        """
        initial_vals = {}

        # get hyper parameter configs for each s
        for s in reversed(range(self.s_max+1)):
            # initial number of configurations
            n = int(math.ceil(((self.s_max+1)//(s+1))*math.pow(self.eta, s)))
            # initial number of iterations to run configurations for
            r = self.R * math.pow(self.eta, -s)
            initial_vals[s] = (n, int(round(r)))

        self.start_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)

        # The diagonal loop below trains brackets s_max down to skip_last only,
        # so sample exactly the configs of those brackets. Select them by key
        # (not by dict position) so this does not depend on dict ordering.
        num_configs = sum(initial_vals[s][0] for s in initial_vals if s >= self.skip_last)
        random_search = MstSearch(self.schema_madlib,
                                  self.model_arch_table,
                                  self.model_selection_table,
                                  self.model_id_list,
                                  self.compile_params_grid,
                                  self.fit_params_grid,
                                  'random',
                                  num_configs,
                                  self.random_state,
                                  self.object_table)
        random_search.load() # for populating mst tables

        # for creating the summary table for usage in fit multiple
        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
                     "SELECT * FROM {random_search.model_selection_summary_table}".format(AutoMLSchema=AutoMLConstants,
                                                                                          random_search=random_search))
        ranges_dict = self.mst_key_ranges_dict(initial_vals)
        # to store the bracket and round numbers
        s_dict, i_dict = {}, {}
        for key, val in ranges_dict.items():
            for mst_key in range(val[0], val[1]+1):
                s_dict[mst_key] = key
                i_dict[mst_key] = -1

        # outer loop on diagonal
        for i in range((self.s_max+1) - int(self.skip_last)):
            # inner loop on s desc
            temp_lst = []
            configs_prune_lookup = {}
            for s in range(self.s_max, self.s_max-i-1, -1):
                n = initial_vals[s][0]
                n_i = n * math.pow(self.eta, -i+self.s_max-s)
                configs_prune_lookup[s] = int(round(n_i))
                # report the same (rounded) count that is actually used for pruning
                temp_lst.append("{0} configs under bracket={1} & round={2}".format(
                    configs_prune_lookup[s], s, s-self.s_max+i))
            num_iterations = int(initial_vals[self.s_max-i][1])
            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
                num_iterations))

            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup) # has keys to evaluate
            active_keys = plpy.execute("SELECT {ModelSelectionSchema.MST_KEY} " \
                                       "FROM {AutoMLSchema.TEMP_MST_TABLE}".format(AutoMLSchema=AutoMLConstants,
                                                                                   ModelSelectionSchema=ModelSelectionSchema))
            for k in active_keys:
                i_dict[k[ModelSelectionSchema.MST_KEY]] += 1
            self.warm_start = int(i != 0)
            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
            with SetGUC("plan_cache_mode", "force_generic_plan"):
                model_training = FitMultipleModel(self.schema_madlib, self.source_table,
                                                  AutoMLConstants.TEMP_OUTPUT_TABLE,
                                                  AutoMLConstants.TEMP_MST_TABLE, num_iterations, self.use_gpus,
                                                  self.validation_table, mcf, self.warm_start, self.name,
                                                  self.description)
            self.update_model_output_table(model_training)
            self.update_model_output_info_table(i, model_training, initial_vals)

            self.print_best_mst_so_far()

        self.end_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)
        self.add_additional_info_cols(s_dict, i_dict)
        self.update_model_selection_table()
        self.generate_model_output_summary_table(model_training)
        self.remove_temp_tables(model_training)
        # cleanup_madlib_temp_tables(self.schema_madlib, AutoMLConstants.TARGET_SCHEMA)

    def mst_key_ranges_dict(self, initial_vals):
        """
        Map each SHA bracket to the inclusive (first, last) mst_key range of the
        configurations sampled for it.

        Keys are handed out contiguously starting at 1 for bracket s_max, then
        continuing downward bracket by bracket.
        :param initial_vals: dict bracket -> (num_configs, num_iterations).
        :return: dict bracket -> (first_mst_key, last_mst_key).
        """
        ranges = {}
        for bracket in sorted(initial_vals.keys(), reverse=True):
            count = initial_vals[bracket][0]
            if bracket == self.s_max:
                first = 1
            else:
                # continue right after the previous (higher) bracket's range
                first = ranges[bracket + 1][1] + 1
            ranges[bracket] = (first, first + count - 1)
        return ranges

    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
        """
        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.

        At diagonal i, bracket s_max - i is in its first round, so all of its
        configurations are copied from the model selection table. Every other
        bracket in configs_prune_lookup is pruned to its best
        configs_prune_lookup[s] configurations (lowest LOSS_METRIC) taken from
        the model info table.
        :param i: outer diagonal loop iteration.
        :param ranges_dict: model config ranges to group by bracket number.
        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
        :return:
        """
        if i == 0:
            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
            lower_bound, upper_bound = ranges_dict[self.s_max]
            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
                         "WHERE {ModelSelectionSchema.MST_KEY} >= {lower_bound} " \
                         "AND {ModelSelectionSchema.MST_KEY} <= {upper_bound}".format(self=self,
                                                                                      AutoMLSchema=AutoMLConstants,
                                                                                      lower_bound=lower_bound,
                                                                                      upper_bound=upper_bound,
                                                                                      ModelSelectionSchema=ModelSelectionSchema))
            return
        # dropping and repopulating temp_mst_table
        drop_tables([AutoMLConstants.TEMP_MST_TABLE])

        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
        create_query = """
                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
                            {mst_key} INTEGER,
                            {model_id} INTEGER,
                            {compile_params} VARCHAR,
                            {fit_params} VARCHAR,
                            unique ({model_id}, {compile_params}, {fit_params})
                        );
                       """.format(AutoMLSchema=AutoMLConstants,
                                  mst_key=ModelSelectionSchema.MST_KEY,
                                  model_id=ModelSelectionSchema.MODEL_ID,
                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
        plpy.execute(create_query)

        # The bracket starting its first round on this diagonal. Chosen
        # explicitly: relying on dict iteration order picks a different
        # bracket under Python 3 insertion ordering than under Python 2.
        new_configs_bracket = self.s_max - i
        queries = []
        for s_val in sorted(configs_prune_lookup):
            lower_bound, upper_bound = ranges_dict[s_val]
            if s_val == new_configs_bracket:
                queries.append("INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT {ModelSelectionSchema.MST_KEY}, " \
                               "{ModelSelectionSchema.MODEL_ID}, {ModelSelectionSchema.COMPILE_PARAMS}, " \
                               "{ModelSelectionSchema.FIT_PARAMS} FROM {self.model_selection_table} WHERE " \
                               "{ModelSelectionSchema.MST_KEY} >= {lower_bound} AND {ModelSelectionSchema.MST_KEY} <= " \
                               "{upper_bound};".format(self=self, AutoMLSchema=AutoMLConstants,
                                                       ModelSelectionSchema=ModelSelectionSchema,
                                                       lower_bound=lower_bound, upper_bound=upper_bound))
            else:
                queries.append("INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT {ModelSelectionSchema.MST_KEY}, " \
                               "{ModelSelectionSchema.MODEL_ID}, {ModelSelectionSchema.COMPILE_PARAMS}, " \
                               "{ModelSelectionSchema.FIT_PARAMS} " \
                               "FROM {self.model_info_table} WHERE {ModelSelectionSchema.MST_KEY} >= {lower_bound} " \
                               "AND {ModelSelectionSchema.MST_KEY} <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
                               "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLConstants,
                                                                          ModelSelectionSchema=ModelSelectionSchema,
                                                                          lower_bound=lower_bound, upper_bound=upper_bound,
                                                                          configs_prune_lookup_val=configs_prune_lookup[s_val]))
        plpy.execute("".join(queries))

    def update_model_output_table(self, model_training):
        """
        Merge the weights produced by one hyperband diagonal into the overall
        model output table.

        Steps: overwrite weights of configs that already have a row; copy the
        table, truncate and drop the original, and rename the copy back; then
        append rows for configs trained for the first time in this diagonal.
        :param model_training: Fit Multiple function call object.
        """
        output_tbl = self.model_output_table
        trained_tbl = model_training.original_model_output_table

        # 1) refresh weights for previously trained configs
        refresh_weights_sql = ("UPDATE {out} a SET model_weights=t.model_weights "
                               "FROM {trained} t WHERE a.mst_key=t.mst_key").format(out=output_tbl,
                                                                                    trained=trained_tbl)
        plpy.execute(refresh_weights_sql)

        # 2) copy / truncate / drop / rename to avoid memory blow-ups
        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
            copy_tbl = unique_string('updated_model')
            rebuild_sql = ("CREATE TABLE {tmp} AS SELECT * FROM {out};"
                           "TRUNCATE {out}; "
                           "DROP TABLE {out};").format(tmp=copy_tbl, out=output_tbl)
            plpy.execute(rebuild_sql)
            rename_table(self.schema_madlib, copy_tbl, output_tbl)

        # 3) append configs that are not in the output table yet
        append_new_sql = ("INSERT INTO {out} SELECT * FROM {trained} "
                          "WHERE {trained}.mst_key NOT IN "
                          "(SELECT {key} FROM {out})").format(out=output_tbl,
                                                              trained=trained_tbl,
                                                              key=ModelSelectionSchema.MST_KEY)
        plpy.execute(append_new_sql)

    def update_model_output_info_table(self, i, model_training, initial_vals):
        """
        Merge the info table of one hyperband diagonal run into the overall model output info table.

        Rows of configs that already exist get their per-iteration arrays extended and their final
        metrics replaced; rows of configs trained for the first time are appended.
        :param i: outer diagonal loop iteration.
        :param model_training: Fit Multiple function call object of the run that just finished.
        :param initial_vals: Dictionary of initial configurations and resources as part of the initial
                             hyperband schedule.
        """
        # Warm start: offset this run's iteration numbers by the sum of the second element of the
        # first i entries of the reversed initial_vals values, giving a global iteration counter.
        # NOTE(review): depends on the iteration order of initial_vals -- confirm it is deterministic.
        second_elems = [entry[1] for entry in initial_vals.values()]
        epochs_factor = sum(second_elems[::-1][:i])
        summary_rows = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} "
                                    "FROM {summary}".format(AutoMLSchema=AutoMLConstants,
                                                            summary=model_training.model_summary_table))
        metrics_iters_val = [epochs_factor + it for it in summary_rows[0]['metrics_iters']]

        # validation columns only exist when a validation table was supplied
        if self.validation_table:
            validation_update_q = "validation_metrics_final=t.validation_metrics_final, " \
                                  "validation_loss_final=t.validation_loss_final, " \
                                  "validation_metrics=a.validation_metrics || t.validation_metrics, " \
                                  "validation_loss=a.validation_loss || t.validation_loss, "
        else:
            validation_update_q = ""

        # updates train/val info for any previously trained configs
        training_set_sql = ("UPDATE {info} a SET "
                            "metrics_elapsed_time=a.metrics_elapsed_time || t.metrics_elapsed_time, "
                            "training_metrics_final=t.training_metrics_final, "
                            "training_loss_final=t.training_loss_final, "
                            "training_metrics=a.training_metrics || t.training_metrics, "
                            "training_loss=a.training_loss || t.training_loss, ").format(info=self.model_info_table)
        iters_and_join_sql = ("{AutoMLSchema.METRICS_ITERS}=a.metrics_iters || ARRAY{metrics_iters_val}::INTEGER[] "
                              "FROM {src} t "
                              "WHERE a.mst_key=t.mst_key").format(AutoMLSchema=AutoMLConstants,
                                                                  metrics_iters_val=metrics_iters_val,
                                                                  src=model_training.model_info_table)
        plpy.execute(training_set_sql + validation_update_q + iters_and_join_sql)

        # inserts info about metrics and validation for newly trained model configs
        insert_new_sql = ("INSERT INTO {info} SELECT t.*, ARRAY{metrics_iters_val}::INTEGER[] AS metrics_iters "
                          "FROM {src} t WHERE t.mst_key NOT IN "
                          "(SELECT {mst_key} FROM {info})").format(info=self.model_info_table,
                                                                     src=model_training.model_info_table,
                                                                     metrics_iters_val=metrics_iters_val,
                                                                     mst_key=ModelSelectionSchema.MST_KEY)
        plpy.execute(insert_new_sql)

    def add_additional_info_cols(self, s_dict, i_dict):
        """
        Adds s and i columns to the info table and fills them in per mst_key.

        :param s_dict: Dictionary mapping mst_key -> value stored in column 's'.
        :param i_dict: Dictionary mapping mst_key -> value stored in column 'i'; must contain every
                       key of s_dict.
        """
        plpy.execute("ALTER TABLE {self.model_info_table} ADD COLUMN s int, ADD COLUMN i int;".format(self=self))

        if not s_dict:
            # An empty ARRAY[] literal has no element type and would make the UPDATE fail;
            # with nothing to fill in, the new columns simply stay NULL.
            return

        # Build the row literals explicitly with int() so the SQL does not depend on Python's repr
        # of the values (e.g. Python 2 longs render as '1L', floats as '1.0').
        key_rows = ", ".join("({0}, {1}, {2})".format(int(key), int(s_dict[key]), int(i_dict[key]))
                             for key in s_dict)
        query = "UPDATE {self.model_info_table} t SET s=b.s_val, i=b.i_val FROM unnest(ARRAY[{rows}]) " \
                "b (key integer, s_val integer, i_val integer) WHERE t.mst_key=b.key".format(self=self, rows=key_rows)
        plpy.execute(query)

class AutoMLHyperopt(KerasAutoML):
    """
    This class implements Hyperopt, another automl method that explores awkward (irregular, mixed
    discrete and continuous) search spaces using Random Search or Tree-structured Parzen Estimator (TPE).
    Adaptive TPE (ATPE) support is planned but currently disabled.

    This class runs hyperopt on top of our multiple model training infrastructure, powered by
    Model hOpper Parallelism (MOP), a hybrid of data and task parallelism.

    It inherits the common AutoML behavior from the KerasAutoML base class.
    """
    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
                 model_id_list, compile_params_grid, fit_params_grid, automl_method,
                 automl_params, random_state=None, object_table=None,
                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
                 name=None, description=None, **kwargs):
        """
        Set up and run a hyperopt search.

        Empty automl_method / automl_params fall back to AutoMLConstants.HYPEROPT and
        'num_configs=20, num_iterations=5, algorithm=tpe'. After the base-class initialization,
        self.compile_params_grid and self.fit_params_grid are stripped of newlines and spaces and
        parsed with ast.literal_eval (plpy.error on invalid syntax). Then the automl params are
        validated, the output tables are created and the search is executed (find_hyperopt_config).
        """
        automl_method = automl_method if automl_method else AutoMLConstants.HYPEROPT
        automl_params = automl_params if automl_params else 'num_configs=20, num_iterations=5, algorithm=tpe'
        KerasAutoML.__init__(self, schema_madlib, source_table, model_output_table, model_arch_table,
                             model_selection_table, model_id_list, compile_params_grid, fit_params_grid,
                             automl_method, automl_params, random_state, object_table, use_gpus,
                             validation_table, metrics_compute_frequency, name, description, **kwargs)
        # NOTE: every space is removed, including any inside quoted values of the grids
        self.compile_params_grid = self.compile_params_grid.replace('\n', '').replace(' ', '')
        self.fit_params_grid = self.fit_params_grid.replace('\n', '').replace(' ', '')
        # literal_eval only evaluates Python literals, so parsing user-supplied grids is safe.
        # Catch Exception instead of a bare except so KeyboardInterrupt/SystemExit still propagate.
        try:
            self.compile_params_grid = literal_eval(self.compile_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'compile_params_dict'")
        try:
            self.fit_params_grid = literal_eval(self.fit_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'fit_params_dict'")
        self.validate_and_define_inputs()
        self.num_segments = self.get_num_segments()

        self.create_model_output_table()
        self.create_model_output_info_table()
        self.find_hyperopt_config()

    def get_num_segments(self):
        """
        Determine how many segments the source data is distributed over.

        The distribution rules are read from the source table's '_summary' table.
        :return: get_seg_number() when the rules are "all_segments", otherwise the number of
                 entries in the stored rules.
        """
        summary_tbl = add_postfix(self.source_table, '_summary')
        summary_rows = plpy.execute("SELECT {0} from {1}".format(DISTRIBUTION_RULES, summary_tbl))
        rules = summary_rows[0][DISTRIBUTION_RULES]
        # TODO create constant for all_segments
        return get_seg_number() if rules == "all_segments" else len(rules)

    def validate_and_define_inputs(self):
        """
        Parse 'automl_params' and set self.num_configs, self.num_iters and self.algorithm.

        Parameters that are not specified fall back to the same defaults as the default
        'automl_params' string (num_configs=20, num_iterations=5, algorithm=tpe) instead of being
        left undefined. Errors (via _assert / plpy.error) on unknown params, an unsupported
        algorithm, non-integer or non-positive counts, or an out-of-range metrics_compute_frequency.
        """
        automl_params_dict = extract_keyvalue_params(self.automl_params,
                                                     lower_case_names=True)
        # casting relevant values to int
        for key in automl_params_dict:
            try:
                automl_params_dict[key] = int(automl_params_dict[key])
            except ValueError:
                pass
        _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
                "{0}: Only num_configs, num_iterations, and algorithm may be specified".format(self.module_name))

        # defaults for any param the user did not specify; without these a partially specified
        # 'automl_params' left the attributes unset and failed with an AttributeError below
        self.num_configs = 20
        self.num_iters = 5
        self.algorithm = tpe

        for key in automl_params_dict:
            value = automl_params_dict[key]
            if key == AutoMLConstants.NUM_CONFIGS:
                self.num_configs = value
            elif key == AutoMLConstants.NUM_ITERS:
                self.num_iters = value
            elif key == AutoMLConstants.ALGORITHM:
                # str() guards against a value that was cast to int above (e.g. algorithm=1)
                algorithm_name = str(value).lower()
                if algorithm_name == 'rand':
                    self.algorithm = rand
                elif algorithm_name == 'tpe':
                    self.algorithm = tpe
                # elif algorithm_name == 'atpe':
                #     self.algorithm = atpe
                # TODO: uncomment the above lines after atpe works
                else:
                    plpy.error("{0}: valid algorithm 'automl_params' for hyperopt: 'rand', 'tpe'".format(self.module_name)) # , or 'atpe'
            else:
                plpy.error("{0}: {1} is an invalid automl param".format(self.module_name, key))
        _assert(isinstance(self.num_configs, int) and isinstance(self.num_iters, int),
                "{0}: num_configs and num_iterations in 'automl_params' must be integers".format(self.module_name))
        _assert(self.num_configs > 0 and self.num_iters > 0, "{0}: num_configs and num_iterations in 'automl_params' "
                                                            "must be > 0".format(self.module_name))
        _assert(self._is_valid_metrics_compute_frequency(self.num_iters), "{0}: 'metrics_compute_frequency' "
                                                                          "out of iteration range".format(self.module_name))

    def find_hyperopt_config(self):
        """
        Executes hyperopt on top of MOP.

        Configurations are evaluated in the rounds returned by get_configs_list. In each round the
        selected hyperopt algorithm suggests one configuration per mst_key, the configurations are
        trained together with FitMultipleModel, and each configuration's loss from the round's info
        table is reported back to hyperopt before the next round. Results of every round are
        stacked into the overall model output/info tables.
        """
        make_mst_summary = True
        trials = Trials()
        domain = Domain(None, self.get_search_space())
        rand_state = np.random.RandomState(self.random_state)
        configs_lst = self.get_configs_list(self.num_configs, self.num_segments)

        self.start_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)
        fit_multiple_runtime = 0
        # training object of the previous round; None until the first round has been trained
        model_training = None
        for low, high in configs_lst:
            i, n = low, high - low + 1

            # Using HyperOpt TPE/ATPE to generate parameters
            hyperopt_params = []
            sampled_params = []
            for j in range(i, i + n):
                new_param = self.algorithm.suggest([j], domain, trials, rand_state.randint(0, AutoMLConstants.INT_MAX))
                new_param[0]['status'] = STATUS_RUNNING

                trials.insert_trial_docs(new_param)
                trials.refresh()
                hyperopt_params.append(new_param[0])
                sampled_params.append(new_param[0]['misc']['vals'])

            model_id_list, compile_params, fit_params = self.extract_param_vals(sampled_params)
            msts_list = self.generate_msts(model_id_list, compile_params, fit_params)
            # best-effort cleanup of the previous round's temp tables (nothing to clean up before the
            # first round); failures are deliberately ignored
            if model_training is not None:
                try:
                    self.remove_temp_tables(model_training)
                except Exception:
                    pass
            self.populate_temp_mst_tables(i, msts_list)

            plpy.info("***Evaluating {n} newly suggested model configurations***".format(n=n))
            fit_multiple_start_time = time.time()
            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLConstants.TEMP_OUTPUT_TABLE,
                                              AutoMLConstants.TEMP_MST_TABLE, self.num_iters, self.use_gpus, self.validation_table,
                                              self.metrics_compute_frequency, False, self.name, self.description, fit_multiple_runtime)
            fit_multiple_runtime += time.time() - fit_multiple_start_time
            if make_mst_summary:
                self.generate_mst_summary_table(self.model_selection_summary_table)
                make_mst_summary = False

            # HyperOpt TPE update: mst_keys start at i (see populate_temp_mst_tables), matching the
            # trial ids j passed to suggest() above
            for k, hyperopt_param in enumerate(hyperopt_params, i):
                loss_val = plpy.execute("SELECT {AutoMLSchema.LOSS_METRIC} FROM {model_training.model_info_table} " \
                             "WHERE {ModelSelectionSchema.MST_KEY}={k}".format(AutoMLSchema=AutoMLConstants,
                                                                               ModelSelectionSchema=ModelSelectionSchema,
                                                                               model_training=model_training,
                                                                               k=k))[0][AutoMLConstants.LOSS_METRIC]

                # avoid removing the two lines below (part of Hyperopt updates)
                hyperopt_param['status'] = STATUS_OK
                hyperopt_param['result'] = {'loss': loss_val, 'status': STATUS_OK}
            trials.refresh()

            # stacks info of all model configs together
            self.update_model_output_and_info_tables(model_training)

            self.print_best_mst_so_far()

        self.end_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)
        self.update_model_selection_table()
        self.generate_model_output_summary_table(model_training)
        self.remove_temp_tables(model_training)

    def get_configs_list(self, num_configs, num_segments):
        """
        Split the model configurations into evaluation rounds.

        The number of rounds is num_configs / num_segments rounded to the nearest integer (at least
        one round). Every round but the last holds exactly num_segments configs; the last round
        takes all remaining ones.
        :param num_configs: total number of configurations to evaluate.
        :param num_segments: number of segments the data is distributed over.
        :return: list of (first_mst_key, last_mst_key) tuples, 1-based and inclusive.
        """
        bucket_count = int(round(float(num_configs) / num_segments))
        full_buckets = max(bucket_count - 1, 0)
        schedule = [(1 + b * num_segments, (b + 1) * num_segments) for b in range(full_buckets)]
        last_start = 1 + full_buckets * num_segments
        schedule.append((last_start, num_configs))
        return schedule

    def get_search_space(self):
        """
        Converts user inputs to hyperopt search space.

        Each entry of the compile params' optimizer params list becomes its own sub-space. All
        sub-spaces share the expressions built for model_id, the fit params and the remaining
        compile params. Optimizer param labels get a 1-based '_<position>' suffix so that every
        sub-space uses unique hyperopt labels.
        :return: Hyperopt search space (hp.choice over the sub-spaces, labelled 'space')
        """
        # params shared by every sub-space (everything outside 'optimizer_params_list')
        shared_space = {}
        shared_space['model_id'] = self.get_hyperopt_exps('model_id', self.model_id_list)

        for fit_param, values in self.fit_params_grid.items():
            shared_space[fit_param] = self.get_hyperopt_exps(fit_param, values)

        for compile_param, values in self.compile_params_grid.items():
            if compile_param != ModelSelectionSchema.OPTIMIZER_PARAMS_LIST:
                shared_space[compile_param] = self.get_hyperopt_exps(compile_param, values)

        sub_spaces = []
        optimizer_options = self.compile_params_grid[ModelSelectionSchema.OPTIMIZER_PARAMS_LIST]
        for position, optimizer_dict in enumerate(optimizer_options, 1):
            suffix = '_' + str(position)
            labelled = [(o_param + suffix, self.get_hyperopt_exps(o_param + suffix, optimizer_dict[o_param]))
                        for o_param in optimizer_dict]
            shared_space.update(labelled)
            # shallow copy: the hyperopt expressions themselves are shared between sub-spaces
            sub_spaces.append(dict(shared_space))
            for label, _ in labelled:
                del shared_space[label]

        return hp.choice('space', sub_spaces)

    def get_hyperopt_exps(self, cp, param_value_list):
        """
        Build the hyperopt expression for one param: either a distribution or a discrete choice.

        A list [lower_bound, upper_bound, distribution_type] whose only string (and only callable-free
        part) is the last element is treated as a distribution ('linear' -> hp.uniform,
        'log' -> hp.loguniform). Any other list is treated as discrete options for hp.choice.
        :param cp: hyperopt label of the param (optimizer params carry a '_<n>' suffix)
        :param param_value_list: list of values (or specified distribution) for a param
        :return: hyperopt expression for the param
        """
        _assert(len(param_value_list) > 0,
                "{0}: '{1}' should have at least one value".format(self.module_name, cp))

        # check if need to sample from a distribution
        is_distribution = len(param_value_list) > 1 and isinstance(param_value_list[-1], str) and \
            all(not isinstance(v, str) and not callable(v) for v in param_value_list[:-1])
        if not is_distribution:
            # random sampling
            return hp.choice(cp, param_value_list)

        _assert_equal(len(param_value_list), 3,
                      "{0}: '{1}' should have exactly 3 elements if picking from a distribution".format(self.module_name, cp))
        _assert(param_value_list[1] > param_value_list[0],
                "{0}: '{1}' should be of the format [lower_bound, upper_bound, distribution_type]".format(self.module_name, cp))
        distribution = param_value_list[-1]
        if distribution == 'linear':
            return hp.uniform(cp, param_value_list[0], param_value_list[1])
        if distribution == 'log':
            # np.log of a non-positive bound gives -inf/nan, which hyperopt cannot sample from
            _assert(param_value_list[0] > 0,
                    "{0}: '{1}' lower bound must be > 0 for a 'log' distribution".format(self.module_name, cp))
            return hp.loguniform(cp, np.log(param_value_list[0]), np.log(param_value_list[1]))

        # Only optimizer params carry a numeric '_<n>' suffix (same classification as in
        # extract_param_vals); original_param_details would raise ValueError on other names.
        if cp == 'model_id' or cp in self.fit_params_grid or cp in self.compile_params_grid:
            param_name = cp
        else:
            param_name = self.original_param_details(cp)[0]
        plpy.error("{0}: Please choose a valid distribution type for '{1}': {2}".format(
            self.module_name,
            param_name,
            ['linear', 'log']))

    def extract_param_vals(self, sampled_params):
        """
        Extract parameter values from hyperopt search space.

        Each sampled value is first tried as an index into the corresponding grid (discrete choice).
        If indexing raises TypeError (e.g. a float sampled from a distribution), the sampled value
        itself is used.
        :param sampled_params: params suggested by hyperopt; one dict per config mapping each label
                               to a list that is empty when the label was not part of the chosen
                               sub-space.
        :return: lists of model ids, compile params dicts and fit params dicts (one entry per config).
        """
        model_id_list, compile_params, fit_params = [], [], []
        for params_dict in sampled_params:
            compile_dict, fit_dict, optimizer_params_dict = {}, {}, {}
            for p in params_dict:
                # skip labels without a sampled value and the 'space' sub-space selector
                if len(params_dict[p]) == 0 or p == 'space':
                    continue
                val = params_dict[p][0]
                if p == 'model_id':
                    model_id_list.append(self.model_id_list[val])
                elif p in self.fit_params_grid:
                    try:
                        # check if val is an index
                        fit_dict[p] = self.fit_params_grid[p][val]
                    except TypeError:
                        # sampled from a distribution: store the value itself, not the 1-element
                        # list hyperopt wraps it in (consistent with the branches below)
                        fit_dict[p] = val
                elif p in self.compile_params_grid:
                    try:
                        # check if val is an index
                        compile_dict[p] = self.compile_params_grid[p][val]
                    except TypeError:
                        compile_dict[p] = val
                else:
                    o_param, idx = self.original_param_details(p) # extracting unique attribute
                    try:
                        # check if val is an index (i.e. optimizer, for example)
                        optimizer_params_dict[o_param] = self.compile_params_grid[
                            ModelSelectionSchema.OPTIMIZER_PARAMS_LIST][idx][o_param][val]
                    except TypeError:
                        optimizer_params_dict[o_param] = val
            compile_dict[ModelSelectionSchema.OPTIMIZER_PARAMS_LIST] = optimizer_params_dict

            compile_params.append(compile_dict)
            fit_params.append(fit_dict)

        return model_id_list, compile_params, fit_params

    def original_param_details(self, name):
        """
        Split a suffixed hyperopt label into the original param name and its 0-based position.

        :param name: name of the param (example - lr_1, epsilon_12)
        :return: (text before the last '_', integer after the last '_' minus 1)
        """
        base_name, _, position = name.rpartition('_')
        return base_name, int(position) - 1


    def generate_msts(self, model_id_list, compile_params, fit_params):
        """
        Combine the i-th model id, compile params and fit params into one mst dictionary each.

        :param model_id_list: list of model ids
        :param compile_params: list of compile params dicts
        :param fit_params: list of fit params dicts
        :return: list of msts (model id plus row-string encoded compile/fit params) for the mst table.
        """
        assert len(model_id_list) == len(compile_params) == len(fit_params)
        return [
            {
                ModelSelectionSchema.MODEL_ID: model_id_list[idx],
                ModelSelectionSchema.COMPILE_PARAMS: generate_row_string(compile_dict),
                ModelSelectionSchema.FIT_PARAMS: generate_row_string(fit_params[idx]),
            }
            for idx, compile_dict in enumerate(compile_params)
        ]

    def populate_temp_mst_tables(self, i, msts_list):
        """
        Recreate the temp mst table, insert the newly suggested configs, and regenerate the temp mst
        summary table.

        :param i: mst key of the first entry in msts_list; later entries get consecutive keys.
        :param msts_list: list of generated msts (dicts keyed by model id / compile / fit params).
        """
        # extra sanity check: start from a clean table
        if table_exists(AutoMLConstants.TEMP_MST_TABLE):
            drop_tables([AutoMLConstants.TEMP_MST_TABLE])

        create_query = """
                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
                            {mst_key} INTEGER,
                            {model_id} INTEGER,
                            {compile_params} VARCHAR,
                            {fit_params} VARCHAR,
                            unique ({model_id}, {compile_params}, {fit_params})
                        );
                       """.format(AutoMLSchema=AutoMLConstants,
                                  mst_key=ModelSelectionSchema.MST_KEY,
                                  model_id=ModelSelectionSchema.MODEL_ID,
                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
        plpy.execute(create_query)

        # one INSERT per mst, keys numbered consecutively from i
        for key_val, entry in enumerate(msts_list, i):
            insert_query = """
                            INSERT INTO
                                {AutoMLSchema.TEMP_MST_TABLE}(
                                    {mst_key_col},
                                    {model_id_col},
                                    {compile_params_col},
                                    {fit_params_col}
                                )
                            VALUES (
                                {mst_key_val},
                                {model_id},
                                $${compile_params}$$,
                                $${fit_params}$$
                            )
                           """.format(mst_key_col=ModelSelectionSchema.MST_KEY,
                                      model_id_col=ModelSelectionSchema.MODEL_ID,
                                      compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                                      fit_params_col=ModelSelectionSchema.FIT_PARAMS,
                                      AutoMLSchema=AutoMLConstants,
                                      mst_key_val=key_val,
                                      model_id=entry[ModelSelectionSchema.MODEL_ID],
                                      compile_params=entry[ModelSelectionSchema.COMPILE_PARAMS],
                                      fit_params=entry[ModelSelectionSchema.FIT_PARAMS])
            plpy.execute(insert_query)

        self.generate_mst_summary_table(AutoMLConstants.TEMP_MST_SUMMARY_TABLE)

    def generate_mst_summary_table(self, tbl_name):
        """
        Recreate the mst summary table `tbl_name` with one row holding the model arch table name and
        the object table name (NULL::VARCHAR when no object table was given).

        :param tbl_name: name of summary table; must end with '_summary'.
        """
        _assert(tbl_name.endswith('_summary'), 'invalid summary table name')

        # extra sanity check
        if table_exists(tbl_name):
            drop_tables([tbl_name])

        create_query = """
                        CREATE TABLE {tbl_name} (
                            {model_arch_table} VARCHAR,
                            {object_table} VARCHAR
                        );
                       """.format(tbl_name=tbl_name,
                                  model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                                  object_table=ModelSelectionSchema.OBJECT_TABLE)
        plpy.execute(create_query)

        object_table_sql = 'NULL::VARCHAR' if self.object_table is None else '$${0}$$'.format(self.object_table)
        insert_summary_query = """
                        INSERT INTO
                            {tbl_name}(
                                {model_arch_table_name},
                                {object_table_name}
                        )
                        VALUES (
                            $${self.model_arch_table}$$,
                            {object_table}
                        )
                       """.format(model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                                  object_table_name=ModelSelectionSchema.OBJECT_TABLE,
                                  tbl_name=tbl_name,
                                  self=self,
                                  object_table=object_table_sql)
        plpy.execute(insert_summary_query)

    def update_model_output_and_info_tables(self, model_training):
        """
        Updates model output and info tables by stacking rows after each evaluation round.

        Every appended info row gets the round's metrics_iters array (read from the round's
        '<original_model_output_table>_summary' table) as an extra trailing column; a typed NULL is
        used when that array is empty or NULL.
        :param model_training: Fit Multiple class object
        """
        metrics_iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
                                     "FROM {model_training.original_model_output_table}_summary".format(
                                         model_training=model_training,
                                         AutoMLSchema=AutoMLConstants))[0][AutoMLConstants.METRICS_ITERS]
        if metrics_iters:
            metrics_iters_sql = "ARRAY{0}".format(metrics_iters)
        else:
            # Formatting None / [] directly would produce invalid SQL ("SELECT *, None FROM ...").
            # Typed as INTEGER[] to match the ::INTEGER[] metrics_iters arrays used in this module.
            metrics_iters_sql = "NULL::INTEGER[]"
        # stacking new rows from training
        plpy.execute("INSERT INTO {self.model_output_table} SELECT * FROM " \
                     "{model_training.original_model_output_table}".format(self=self, model_training=model_training))
        plpy.execute("INSERT INTO {self.model_info_table} SELECT *, {metrics_iters} FROM " \
                     "{model_training.model_info_table}".format(self=self,
                                                                model_training=model_training,
                                                                metrics_iters=metrics_iters_sql))
