# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import plpy
import time
import sys
import json
import random
import datetime
from collections import defaultdict
# from tensorflow.keras.models import *

from madlib_keras import compute_loss_and_metrics
from madlib_keras import get_model_arch
from madlib_keras import get_source_summary_table_dict
from madlib_keras import should_compute_metrics_this_iter
from madlib_keras import get_initial_weights
from madlib_keras_helper import *
from madlib_keras_model_selection import ModelSelectionSchema
from madlib_keras_validator import *
from madlib_keras_wrapper import *

from internal.db_utils import quote_literal
from utilities.control import OptimizerControl
from utilities.control import SetGUC
from utilities.utilities import add_postfix
from utilities.utilities import is_platform_gp6_or_up
from utilities.utilities import unique_string
from utilities.utilities import rotate
from utilities.utilities import madlib_version
from utilities.utilities import is_platform_pg
from utilities.utilities import get_seg_number
import utilities.debug as DEBUG
from utilities.debug import plpy_prepare
from utilities.debug import plpy_execute

# Module-wide switches on utilities.debug (imported above as DEBUG).
# All of them are turned off here.
DEBUG.timings_enabled = False
DEBUG.mst_keys_enabled = False
DEBUG.plpy_execute_enabled = False
DEBUG.plpy_info_enabled = False


"""
FitMultipleModel: This class implements the Model Hopper technique for
training multiple models in parallel. The goal of this class is to train
multiple different models on the same data with different parameters.
The main advantage of this method over running the existing fit function in a
loop is avoiding the inaccuracies caused by model averaging. The basic idea of
model hopper is simple. Let's assume that there are n segments and c*n model
configurations. We begin by distributing these configs to the segments. After
that, each segment trains its c models on the data it has locally for one
iteration. Once we have these models, each segment moves them to a different
segment (hopping) and receives a different set of models in return. Once a
segment has its new models, it uses its local data to refine them (similar to
the warm start functionality). Once every model has hopped through every
segment, we consider an iteration complete.

This method ensures that we don't have to average any models, and the
loss and accuracy are very close to the ideal case, where all of the data is
on one segment.

Note that this function is disabled for Postgres.
"""

class FitMultipleModel(object):
    def __init__(self, schema_madlib, source_table, model_output_table,
                 model_selection_table, num_iterations,
                 use_gpus=False, validation_table=None,
                 metrics_compute_frequency=None, warm_start=False, name="",
                 description="", use_caching=False, metrics_elapsed_time_offset=0, **kwargs):
        """
        Validate the inputs and build the state needed to train the models
        listed in the model selection table (schedule, dist keys, temp table
        names, per-mst metric containers).

        :param schema_madlib: schema name
        :param source_table: input table containing training dataset
        :param model_output_table: output table
        :param model_selection_table: input table containing model configs
        :param num_iterations: number of iterations to train
        :param use_gpus: determines whether GPUs are to be used for training
        :param validation_table: input table containing the validation dataset
        :param metrics_compute_frequency: Frequency to compute per-iteration metrics
                                          for the training dataset and validation dataset
                                          (defaults to num_iterations when None)
        :param warm_start: indicates whether to initialize weights with the coefficients
                           from the last call of the fit function
        :param name: name
        :param description: description
        :param use_caching: stored on the instance (None is treated as False);
                            when set, the cached source table is dropped once
                            training finishes
        :param metrics_elapsed_time_offset: time elapsed for the previous call to fit_multiple
                                            (internal param used by automl to accumulate
                                             metrics_elapsed_time)
        :param kwargs: extra keyword arguments; accepted but not used here
        """
        # set the random seed for visit order/scheduling
        random.seed(1)
        if is_platform_pg():
            plpy.error(
                "DL: Multiple model training is not supported on PostgreSQL.")
        self.source_table = source_table
        self.validation_table = validation_table
        self.model_selection_table = model_selection_table
        # Always define the attribute: it is read unconditionally when the
        # input validator is built below, so leaving it unset for a missing
        # model selection table would raise AttributeError before the
        # validator gets a chance to look at the inputs.
        self.model_selection_summary_table = None
        if self.model_selection_table:
            self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')

        self.dist_key_col = DISTRIBUTION_KEY_COLNAME
        self.prev_dist_key_col = '__prev_dist_key__'
        self.num_iterations = num_iterations
        self.metrics_compute_frequency = metrics_compute_frequency
        self.name = name
        self.description = description
        self.use_caching = use_caching if use_caching is not None else False
        self.module_name = 'madlib_keras_fit_multiple_model'
        self.schema_madlib = schema_madlib
        self.version = madlib_version(self.schema_madlib)
        self.mst_key_col = ModelSelectionSchema.MST_KEY
        self.model_id_col = ModelSelectionSchema.MODEL_ID
        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
        self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
        self.model_weights_col = ModelArchSchema.MODEL_WEIGHTS
        self.model_arch_col = ModelArchSchema.MODEL_ARCH
        # Per-mst_key histories for the training set, appended to on every
        # evaluation
        self.train_mst_metric_eval_time = defaultdict(list)
        self.train_mst_loss = defaultdict(list)
        self.train_mst_metric = defaultdict(list)
        self.info_str = ""
        source_summary_table = add_postfix(self.source_table, "_summary")
        input_tbl_valid(source_summary_table, self.module_name)
        src_summary_dict = get_source_summary_table_dict(source_summary_table)

        self.mb_dep_var_cols = src_summary_dict['dependent_varname']
        self.mb_indep_var_cols = src_summary_dict['independent_varname']
        self.dep_shape_cols = [add_postfix(i, "_shape") for i in self.mb_dep_var_cols]
        self.ind_shape_cols = [add_postfix(i, "_shape") for i in self.mb_indep_var_cols]

        # Column names bundled into one dict; handed to
        # compute_loss_and_metrics during evaluation
        self.columns_dict = {}
        self.columns_dict['mb_dep_var_cols'] = self.mb_dep_var_cols
        self.columns_dict['mb_indep_var_cols'] = self.mb_indep_var_cols
        self.columns_dict['dep_shape_cols'] = self.dep_shape_cols
        self.columns_dict['ind_shape_cols'] = self.ind_shape_cols

        self.val_dep_var = None
        self.val_ind_var = None
        self.val_dep_shape_cols = None
        self.val_ind_shape_cols = None
        if validation_table:
            validation_summary_table = add_postfix(self.validation_table, "_summary")
            input_tbl_valid(validation_summary_table, self.module_name)
            val_summary_dict = get_source_summary_table_dict(validation_summary_table)

            self.val_dep_var = val_summary_dict['dependent_varname']
            self.val_ind_var = val_summary_dict['independent_varname']
            self.val_dep_shape_cols = [add_postfix(i, "_shape") for i in self.val_dep_var]
            self.val_ind_shape_cols = [add_postfix(i, "_shape") for i in self.val_ind_var]

        self.columns_dict['val_dep_var'] = self.val_dep_var
        self.columns_dict['val_ind_var'] = self.val_ind_var
        self.columns_dict['val_dep_shape_cols'] = self.val_dep_shape_cols
        self.columns_dict['val_ind_shape_cols'] = self.val_ind_shape_cols

        self.use_gpus = use_gpus if use_gpus else False
        # Unique names for the working tables used during training
        self.model_input_tbl = unique_string('model_input')
        self.model_output_tbl = unique_string('model_output')
        self.schedule_tbl = unique_string('schedule')
        self.next_schedule_tbl = unique_string('next_schedule')
        self.cached_source_table = unique_string('cached_source_table')
        self.metrics_elapsed_time_offset = metrics_elapsed_time_offset
        # Prepared plans are created lazily on first use
        self.rotate_schedule_tbl_plan = self.add_object_maps_plan = None
        self.hop_plan = self.udf_plan = None

        self.segments_per_host = get_data_distribution_per_segment(source_table)
        if self.use_gpus:
            self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                self.schema_madlib, self.segments_per_host, self.module_name)
        else:
            self.accessible_gpus_for_seg = get_seg_number()*[0]

        self.original_model_output_tbl = model_output_table
        if not self.original_model_output_tbl:
            plpy.error("Must specify an output table.")

        self.model_info_tbl = add_postfix(
            self.original_model_output_tbl, '_info')
        self.model_summary_table = add_postfix(
            self.original_model_output_tbl, '_summary')

        self.warm_start = bool(warm_start)

        self.fit_validator_train = FitMultipleInputValidator(
            self.source_table, self.validation_table, self.original_model_output_tbl,
            self.model_selection_table, self.model_selection_summary_table,
            self.mb_dep_var_cols, self.mb_indep_var_cols, self.dep_shape_cols,
            self.ind_shape_cols, self.num_iterations,
            self.model_info_tbl, self.mst_key_col, self.model_arch_table_col,
            self.metrics_compute_frequency, self.warm_start, self.use_gpus,
            self.accessible_gpus_for_seg, self.val_dep_var, self.val_ind_var)
        # The validator receives the user-supplied frequency (possibly None);
        # the default is only applied afterwards.
        if self.metrics_compute_frequency is None:
            self.metrics_compute_frequency = num_iterations

        self.msts = self.fit_validator_train.msts
        self.model_arch_table = self.fit_validator_train.model_arch_table
        self.object_table = self.fit_validator_train.object_table
        self.metrics_iters = []
        self.object_map_col = 'object_map'
        self.custom_mst_keys = None
        if self.object_table is not None:
            self.populate_object_map()

        # Remember the current CUDA setting (if any) so it can be handed to
        # reset_cuda_env() at the end of fit_multiple_model()
        self.original_cuda_env = None
        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
            self.original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]

        self.dist_key_mapping, self.images_per_seg_train = \
            get_image_count_per_seg_for_minibatched_data_from_db(
                self.source_table, self.dep_shape_cols[0])

        if self.validation_table:
            self.valid_mst_metric_eval_time = defaultdict(list)
            self.valid_mst_loss = defaultdict(list)
            self.valid_mst_metric = defaultdict(list)
            self.dist_key_mapping_valid, self.images_per_seg_valid = \
                get_image_count_per_seg_for_minibatched_data_from_db(
                    self.validation_table, self.val_dep_shape_cols[0])

        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
        self.max_dist_key = sorted(self.dist_keys)[-1]
        self.extra_dist_keys = []

        num_msts = self.num_msts = len(self.msts)
        num_dist_keys = len(self.dist_keys)

        if num_msts < num_dist_keys:
            # Fewer models than dist keys: pad the schedule with None entries
            self.msts_for_schedule = self.msts + [None] * \
                                     (num_dist_keys - num_msts)
        else:
            # NOTE: this is the same list object as self.msts, so the shuffle
            # below reorders self.msts as well
            self.msts_for_schedule = self.msts
            if num_msts > num_dist_keys:
                # More models than dist keys: make up extra dist keys above
                # the largest real one
                for i in range(num_msts - num_dist_keys):
                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)

        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))

        random.shuffle(self.msts_for_schedule)

        # Ordered list of sql representations of each mst_key,
        #  including NULL's.  This will be used to pass the mst keys
        #  to the db as a sql ARRAY[]
        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
                for mst in self.msts_for_schedule ]

        # List of all dist_keys, including any extra dist keys beyond
        #  the # segments we'll be training on--these represent the
        #  segments models will rest on while not training, which
        #  may overlap with the ones that will have training on them.
        self.all_dist_keys = self.dist_keys + self.extra_dist_keys

        self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''

    def fit_multiple_model(self):
        """
        Top-level driver: create the working tables, train every model with
        the optimizer control turned off, then write the info, summary and
        final model output tables and reset the CUDA environment.
        """
        # Working tables, created in this order before any training happens
        for prepare_step in (self.init_schedule_tbl,
                             self.init_model_output_tbl,
                             self.init_model_info_tbl):
            prepare_step()

        # WARNING: set orca off to prevent unwanted redistribution
        with OptimizerControl(False):
            # Wall-clock start/end of training, plus the reference point for
            # the elapsed times recorded during evaluation
            self.start_training_time = datetime.datetime.now()
            self.metrics_elapsed_start_time = time.time()
            self.train_multiple_model()
            self.end_training_time = datetime.datetime.now()

        # Update and cleanup metadata tables
        for finalize_step in (self.insert_info_table,
                              self.create_model_summary_table,
                              self.write_final_model_output_tbl):
            finalize_step()

        reset_cuda_env(self.original_cuda_env)

    def write_final_model_output_tbl(self):
        """
        Publish the trained models under the user-facing output table name.

        1. drop the user-facing output table if it exists
        2. recreate it from the working model output table, keeping only the
           mst key, weights, architecture and dist key columns
        3. truncate and drop the working model output table
        :return:
        """
        publish_sql_template = """
                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
                                    CREATE TABLE {self.original_model_output_tbl} AS
                                    SELECT
                                        {self.mst_key_col}::INTEGER,
                                        {self.model_weights_col}::BYTEA,
                                        {self.model_arch_col}::JSON,
                                        {self.dist_key_col}::INTEGER
                                    FROM {self.model_output_tbl}
                                    DISTRIBUTED BY ({self.dist_key_col})
                                    """
        plpy.execute(publish_sql_template.format(self=self))
        self.truncate_and_drop(self.model_output_tbl)

    def train_multiple_model(self):
        """
        Run all training hops for every iteration.

        Each iteration performs one hop per entry in the schedule, rotating
        the schedule table between hops.  After an iteration, metrics are
        evaluated when should_compute_metrics_this_iter() says so, and the
        accumulated info string is reported.  The schedule table (and the
        cached source table when caching is on) is dropped at the end.
        """
        num_hops = len(self.all_mst_keys)
        final_hop = num_hops - 1
        DEBUG.start_timing('train_multiple_model_extra')

        for iteration in range(1, self.num_iterations + 1):
            for hop in range(num_hops):
                on_first_hop = (hop == 0)
                on_final_hop = (hop == final_hop)
                self.is_final_training_call = (
                    iteration == self.num_iterations and on_final_hop)
                if on_first_hop:
                    start_iteration = time.time()

                # Second argument is True only for the very first hop of the
                # very first iteration
                self.run_training(hop, on_first_hop and iteration == 1)
                DEBUG.start_timing('train_multiple_model_extra')

                if not on_final_hop:
                    # Pair every mst with its next dist key for the next hop
                    self.rotate_schedule_tbl()
                    continue

                end_iteration = time.time()
                # Starts a fresh info string for this iteration
                self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(
                    iteration, end_iteration - start_iteration)

            compute_metrics_now = should_compute_metrics_this_iter(
                iteration, self.metrics_compute_frequency, self.num_iterations)
            if compute_metrics_now:
                self.metrics_iters.append(iteration)
                self.info_str += "\tTraining set after iteration {0}:".format(iteration)
                self.evaluate_model(iteration, self.source_table, True)
                if self.validation_table:
                    self.evaluate_model(iteration, self.validation_table, False)
            plpy.info("\n"+self.info_str)

        plpy.execute("DROP TABLE IF EXISTS {self.schedule_tbl}".format(self=self))
        if self.use_caching:
            plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table}".format(self=self))

    def evaluate_model(self, iter, table, is_train):
        """
        Compute loss and metric of every mst on the given table and record
        the results.

        :param iter: iteration number, used in the progress text
        :param table: table handed to compute_loss_and_metrics()
        :param is_train: True to record into the training histories, False to
                         record into the validation histories
        """
        if is_train:
            label = "training"
            eval_times, loss_history, metric_history = (
                self.train_mst_metric_eval_time,
                self.train_mst_loss,
                self.train_mst_metric)
            seg_ids = self.dist_key_mapping
            images_per_seg = self.images_per_seg_train
        else:
            label = "validation"
            eval_times, loss_history, metric_history = (
                self.valid_mst_metric_eval_time,
                self.valid_mst_loss,
                self.valid_mst_metric)
            seg_ids = self.dist_key_mapping_valid
            images_per_seg = self.images_per_seg_valid
            self.info_str += "\n\t\n\tValidation set after iteration {0}:".format(iter)

        total_eval_compute_time = 0
        for mst in self.msts:
            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
            DEBUG.start_timing('eval_compute_loss_and_metrics')
            eval_compute_time, metric, loss = compute_loss_and_metrics(
                self.schema_madlib,
                table,
                self.columns_dict,
                "$madlib${0}$madlib$".format(mst[self.compile_params_col]),
                model_arch,
                None,
                self.use_gpus,
                self.accessible_gpus_for_seg,
                self.segments_per_host,
                seg_ids,
                images_per_seg,
                [],
                [],
                True,
                mst[self.object_map_col],
                self.model_output_tbl,
                mst[self.mst_key_col],
                is_train)
            total_eval_compute_time += eval_compute_time

            mst_key = mst[self.mst_key_col]
            # Elapsed time since training started, shifted by the offset
            # carried over from a previous call
            elapsed = self.metrics_elapsed_time_offset + (time.time() - self.metrics_elapsed_start_time)
            eval_times[mst_key].append(elapsed)
            loss_history[mst_key].append(loss)
            metric_history[mst_key].append(metric)
            self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(
                mst_key, metric, loss)

        self.info_str += "\n\tTime for evaluating {0} dataset in iteration {1}: {2}".format(
            label, iter, total_eval_compute_time)
        DEBUG.print_timing('eval_model_total')

    def populate_object_map(self):
        """
        Find the msts whose compile params name a loss or metric that is not
        a builtin, look those custom functions up in the object table, and
        attach the resulting object map to each such mst.

        Sets self.custom_fn_object_map and self.custom_mst_keys.  Msts that
        only use builtin functions are left untouched.
        """
        # The lower-cased builtin names do not depend on the mst, so build
        # them once (as sets, for cheap membership tests) rather than once
        # per mst.
        builtin_loss_names = set(name.lower() for name in dir(losses))
        builtin_metric_names = set(
            name.lower() for name in update_builtin_metrics(dir(metrics)))

        # Track distinct custom functions in compile_params
        custom_fn_names = set()
        # Track their corresponding msts to pass along the custom function
        # definition read from the object table.
        custom_msts = []
        for mst in self.msts:
            compile_params = mst[self.compile_params_col]
            # We assume that the compile_param is validated as part
            # of the loading mst_table and thus not validating here
            # Also, it is validated later when we compile the model
            # on the segments
            compile_dict = convert_string_of_args_to_dict(compile_params)

            local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
            # [2:-2] drops the first two and last two characters of the
            # metrics value (presumably the surrounding "['" and "']")
            local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None

            uses_custom_fn = False
            if local_loss and local_loss not in builtin_loss_names:
                custom_fn_names.add(local_loss)
                uses_custom_fn = True
            if local_metric and local_metric not in builtin_metric_names:
                custom_fn_names.add(local_metric)
                uses_custom_fn = True
            # Record the mst once, even when both its loss and its metric
            # are custom
            if uses_custom_fn:
                custom_msts.append(mst)

        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)

        for mst in custom_msts:
            mst[self.object_map_col] = self.custom_fn_object_map

        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }

    def init_schedule_tbl(self):
        """
        Create the initial schedule table.

        The ordered mst keys and dist keys are unnested side by side, and
        the result is left-joined to the model selection table to pick up
        the model id.  Each row starts with its previous dist key equal to
        its current one.
        """
        mst_key_list = '[{0}]'.format(','.join(self.all_mst_keys))

        schedule_sql_template = """
            CREATE {self.unlogged_table} TABLE {self.schedule_tbl} AS
                WITH map AS
                    (SELECT
                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
                    )
                SELECT
                    map.{self.mst_key_col},
                    {self.model_id_col},
                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
                    map.{self.dist_key_col}
                FROM map LEFT JOIN {self.model_selection_table}
                    USING ({self.mst_key_col})
            DISTRIBUTED BY ({self.dist_key_col})
        """
        plpy_execute(schedule_sql_template.format(self=self,
                                                  mst_key_list=mst_key_list))

    def rotate_schedule_tbl(self):
        """
        Build the schedule for the next hop and swap it in for the current
        one.

        In the new schedule each row keeps its current dist key as the
        previous dist key and receives the dist key of the following row
        (ordered by dist key); the last row wraps around to the first dist
        key.  The statement is prepared once and reused on later calls.
        """
        plan = self.rotate_schedule_tbl_plan
        if plan is None:
            plan = plpy.prepare("""
                CREATE {self.unlogged_table} TABLE {self.next_schedule_tbl} AS
                    SELECT
                        {self.mst_key_col},
                        {self.model_id_col},
                        {self.dist_key_col} AS {self.prev_dist_key_col},
                        COALESCE(
                            LEAD({self.dist_key_col})
                                OVER(ORDER BY {self.dist_key_col}),
                            FIRST_VALUE({self.dist_key_col})
                                OVER(ORDER BY {self.dist_key_col})
                        ) AS {self.dist_key_col}
                    FROM {self.schedule_tbl}
                DISTRIBUTED BY ({self.prev_dist_key_col})
            """.format(self=self))
            self.rotate_schedule_tbl_plan = plan

        plpy.execute(plan)

        # Replace the old schedule with the freshly built one
        self.truncate_and_drop(self.schedule_tbl)
        rename_sql = """
            ALTER TABLE {self.next_schedule_tbl}
            RENAME TO {self.schedule_tbl}
        """.format(self=self)
        plpy.execute(rename_sql)

    def load_warm_start_weights(self):
        """
        Bulk load warm start weights into the working model output table.

        Rows of the user-supplied model output table are copied for every
        mst key that also appears in the schedule table and the model
        selection table.  Compile and fit params come from the model
        selection table, dist keys from the schedule table; the object map
        is left NULL here.
        """
        warm_start_sql_template = """
            INSERT INTO {self.model_output_tbl}
                SELECT s.{self.mst_key_col},
                    o.{self.model_weights_col},
                    o.{self.model_arch_col},
                    m.{self.compile_params_col},
                    m.{self.fit_params_col},
                    NULL AS {self.object_map_col}, -- Fill in later
                    s.{self.dist_key_col}
                FROM {self.schedule_tbl} s
                    JOIN {self.model_selection_table} m
                        USING ({self.mst_key_col})
                    JOIN {self.original_model_output_tbl} o
                        USING ({self.mst_key_col})
        """
        plpy_execute(warm_start_sql_template.format(self=self))

    def load_xfer_learning_weights(self, warm_start=False):
        """
        Bulk load transfer learning weights from the model arch table into
        the working model output table.

        Only models whose weights in the model arch table are not NULL are
        copied; the remaining models get their weights generated by keras
        and are added one at a time later.

        :param warm_start: not used by this method
        """
        xfer_sql_template = """
            INSERT INTO {self.model_output_tbl}
                SELECT s.{self.mst_key_col},
                    a.{self.model_weights_col},
                    a.{self.model_arch_col},
                    m.{self.compile_params_col},
                    m.{self.fit_params_col},
                    NULL AS {self.object_map_col}, -- Fill in later
                    s.{self.dist_key_col}
                FROM {self.schedule_tbl} s
                    JOIN {self.model_selection_table} m
                        USING ({self.mst_key_col})
                    JOIN {self.model_arch_table} a
                        ON m.{self.model_id_col} = a.{self.model_id_col}
                WHERE a.{self.model_weights_col} IS NOT NULL;
        """
        plpy_execute(xfer_sql_template.format(self=self))

    def init_model_output_tbl(self):
        """
        Create the working model output table and fill it with one row per
        scheduled mst.

        Models with user-provided weights (warm start or transfer learning)
        are bulk loaded first.  Every other model gets initial weights from
        get_initial_weights() and is inserted one row at a time.  Finally
        the custom object map is written to the rows of msts that use
        custom functions.
        """
        DEBUG.start_timing('init_model_output_and_info')

        create_output_tbl_sql = """
                                    CREATE {self.unlogged_table} TABLE {self.model_output_tbl}
                                    ({self.mst_key_col} INTEGER,
                                     {self.model_weights_col} BYTEA,
                                     {self.model_arch_col} JSON,
                                     {self.compile_params_col} TEXT,
                                     {self.fit_params_col} TEXT,
                                     {self.object_map_col} BYTEA,
                                     {self.dist_key_col} INTEGER,
                                     PRIMARY KEY ({self.dist_key_col})
                                    )
                                    DISTRIBUTED BY ({self.dist_key_col})
                                    """.format(self=self)
        plpy.execute(create_output_tbl_sql)

        if not self.warm_start:
            # Note:  We only support xfer learning when warm_start=False
            self.load_xfer_learning_weights()
        else:
            self.load_warm_start_weights()

        loaded_rows = plpy.execute("""
            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
        """.format(self=self))

        # mst keys that already have a row thanks to the bulk load above
        initialized_msts = set()
        if loaded_rows:
            initialized_msts.update(row['mst_keys'] for row in loaded_rows)

        # We've already bulk loaded all of the models with user-specified weights.
        #  For the rest of the models, we need to generate the weights for each
        #  by initializing them with keras and adding them one row at a time.
        #
        # TODO:  In the future, we should probably move the weight initialization
        #  into the transition function on the segments.  Here, we would just
        #  bulk load everything with a single query (or 2, for the warm start case),
        #  and leave the weights column as NULL for any model whose weights need
        #  to be randomly initialized.  Then in fit_transition, if prev_weights is
        #  NULL, and there is nothing in GD, it should just skip the call to
        #  set_weights(), and keras will automatically initialize them during
        #  model.from_json(model_arch).
        #
        #  This would be a very easy change for fit_multiple(), but might require
        #   some more work to support fit().  All of the segments there need to
        #   start with the same weights, so we'd at least have to pass a random
        #   seed to the transition function for keras to use.  Or generate a seed
        #   on the segments in some deterministic way that's the same for all.
        num_dist_keys = len(self.dist_keys)
        for position, mst in enumerate(self.msts_for_schedule):
            # Skip padding entries and msts that were bulk loaded already
            if mst is None or mst['mst_key'] in initialized_msts:
                continue

            # Positions past the real dist keys belong to models that won't
            # be trained on the first hop; they use the extra dist keys
            if position < num_dist_keys:
                dist_key = self.dist_keys[position]
            else:
                dist_key = self.extra_dist_keys[position - num_dist_keys]

            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
            serialized_weights = get_initial_weights(None, model_arch, None, False,
                                                     self.accessible_gpus_for_seg)

            # Two-pass format: the first pass fills in the column names and
            # leaves {<column name>} placeholders behind, the second pass
            # fills those from the mst itself
            insert_row_sql = """
                INSERT INTO {self.model_output_tbl} (
                    {self.mst_key_col},
                    {self.model_weights_col},
                    {self.model_arch_col},
                    {self.compile_params_col},
                    {self.fit_params_col},
                    {self.object_map_col},
                    {self.dist_key_col}
                ) VALUES (
                    $MADLIB${{{self.mst_key_col}}}$MADLIB$,
                    $1,
                    $2,
                    $MADLIB${{{self.compile_params_col}}}$MADLIB$,
                    $MADLIB${{{self.fit_params_col}}}$MADLIB$,
                    NULL, -- Fill in custom object_map soon
                    $3
                )
            """.format(self=self).format(**mst)

            insert_row_plan = plpy.prepare(insert_row_sql,
                                           ["BYTEA", "JSON", "INTEGER"])
            plpy.execute(insert_row_plan,
                         [serialized_weights, model_arch, dist_key])

        if not self.custom_mst_keys:
            return

        custom_keys = '({})'.format(
            ','.join(str(key) for key in self.custom_mst_keys))

        # Add object_map to any msts which use custom functions
        if self.add_object_maps_plan is None:
            self.add_object_maps_plan = plpy.prepare("""
                    UPDATE {self.model_output_tbl}
                        SET {self.object_map_col} = $1
                            WHERE {self.mst_key_col} IN {custom_keys}
                """.format(self=self, custom_keys=custom_keys), ["BYTEA"])
        plpy.execute(self.add_object_maps_plan, [self.custom_fn_object_map])
    def init_model_info_tbl(self):
        """
        (Re)create the model info table and seed it with one row per mst.

        The key, model id, compile/fit params, model type and model size
        (weights length divided by 1024.0) are inserted by joining the model
        selection table with the working model output table.  The metrics
        type and loss type parsed from each mst's compile params are then
        filled in; both are stored as SQL NULL when absent.  The remaining
        metric and loss columns are not written here.
        """
        info_table_create_query = """
            DROP TABLE IF EXISTS {self.model_info_tbl};
            CREATE TABLE {self.model_info_tbl} (
                {self.mst_key_col} INTEGER PRIMARY KEY,
                {self.model_id_col} INTEGER,
                {self.compile_params_col} TEXT,
                {self.fit_params_col} TEXT,
                model_type TEXT,
                model_size DOUBLE PRECISION,
                metrics_elapsed_time DOUBLE PRECISION[],
                metrics_type TEXT[],
                loss_type TEXT,
                training_metrics_final DOUBLE PRECISION,
                training_loss_final DOUBLE PRECISION,
                training_metrics DOUBLE PRECISION[],
                training_loss DOUBLE PRECISION[],
                validation_metrics_final DOUBLE PRECISION,
                validation_loss_final DOUBLE PRECISION,
                validation_metrics DOUBLE PRECISION[],
                validation_loss DOUBLE PRECISION[]
           ) """.format(self=self)

        plpy.execute(info_table_create_query)

        info_table_insert_query = """
            INSERT INTO {self.model_info_tbl} (
                {self.mst_key_col},
                {self.model_id_col},
                {self.compile_params_col},
                {self.fit_params_col},
                model_type,
                model_size
            )
            SELECT
                m.{self.mst_key_col},
                m.{self.model_id_col},
                m.{self.compile_params_col},
                m.{self.fit_params_col},
                '{model_type}',
                LENGTH(o.{self.model_weights_col})/1024.0
            FROM {self.model_selection_table} m JOIN {self.model_output_tbl} o
                USING ({self.mst_key_col})
        """.format(self=self,
                   model_type='madlib_keras')

        plpy.execute(info_table_insert_query)

        for mst in self.msts_for_schedule:
            # None entries are padding in the schedule, not real models
            if mst is None:
                continue

            metrics_list = get_metrics_from_compile_param(
                mst[self.compile_params_col])
            metrics_type = 'ARRAY{0}'.format(
                metrics_list) if metrics_list else 'NULL'

            # Quote the loss name only when there is one, so that a missing
            # loss is stored as SQL NULL (like metrics_type) and not as the
            # four-character text 'NULL'.
            loss_name = get_loss_from_compile_param(mst[self.compile_params_col])
            loss_type = "'{0}'".format(loss_name) if loss_name else 'NULL'

            # Single formatting pass: values substituted into the statement
            # are never re-interpreted as format fields.
            plpy.execute("""
                UPDATE {self.model_info_tbl} SET
                    metrics_type = {metrics_type},
                    loss_type = {loss_type}
                WHERE {self.mst_key_col} = {mst_key}
            """.format(self=self,
                       metrics_type=metrics_type,
                       loss_type=loss_type,
                       mst_key=mst[self.mst_key_col]))

        # Closes the timing section opened in init_model_output_tbl()
        DEBUG.print_timing('init_model_output_and_info')

    def create_model_summary_table(self):
        """
        Create the model summary table with a CREATE TABLE ... AS SELECT
        over the training source summary table.

        Most columns are constants taken from this object (source and
        validation table names, model/info/arch/selection table names,
        iteration settings, name, description, training timestamps, madlib
        version, metrics_iters).  The per-dependent-variable
        "<dep_var>_class_values" columns and dependent_vartype are selected
        from the source summary table itself; the dependent/independent
        variable names and the normalizing constant come from
        get_source_summary_table_dict().

        Optional values (validation_table, name, description, object_table,
        metrics_iters) are written as SQL NULL when unset.

        When warm_start is set, the existing summary table is dropped first.

        This method does not modify any attribute of self, so it can be
        called more than once.
        """
        if self.warm_start:
            plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
        source_summary_table = self.fit_validator_train.source_summary_table
        src_summary_dict = get_source_summary_table_dict(source_summary_table)

        class_values_colnames = [add_postfix(i, "_class_values")
                                 for i in self.mb_dep_var_cols]
        # Stored as a one-element array holding the number of class-values
        # columns (one per dependent variable column).
        num_classes = len(class_values_colnames)
        class_values_colnames = ' , '.join(class_values_colnames)

        dependent_varname = src_summary_dict['dependent_varname']
        independent_varname = src_summary_dict['independent_varname']
        dep_name_list = ', '.join([quote_literal(i) for i in dependent_varname])
        ind_name_list = ', '.join([quote_literal(i) for i in independent_varname])

        norm_const = src_summary_dict['normalizing_const']

        # Build the SQL form in a local instead of overwriting
        # self.validation_table: other methods test that attribute for
        # truthiness, and re-quoting an already quoted value on a second
        # call would produce a wrong table name.
        validation_table = 'NULL' if self.validation_table is None \
            else '$MAD${0}$MAD$'.format(self.validation_table)
        name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
        descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
        object_table = 'NULL' if self.object_table is None \
            else '$MAD${0}$MAD$'.format(self.object_table)
        # The ARRAY prefix is part of the value so that the empty case
        # renders as NULL::INTEGER[] and not as the invalid token ARRAYNULL.
        metrics_iters = 'ARRAY{0}'.format(self.metrics_iters) \
            if self.metrics_iters else 'NULL'

        create_query = """
                CREATE TABLE {self.model_summary_table} AS
                SELECT
                    $MAD${self.source_table}$MAD$::TEXT AS source_table,
                    {validation_table}::TEXT AS validation_table,
                    $MAD${self.original_model_output_tbl}$MAD$::TEXT AS model,
                    $MAD${self.model_info_tbl}$MAD$::TEXT AS model_info,
                    ARRAY[{dep_name_list}]::TEXT[] AS dependent_varname,
                    ARRAY[{ind_name_list}]::TEXT[] AS independent_varname,
                    $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
                    $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
                    {object_table}::TEXT AS object_table,
                    {self.num_iterations}::INTEGER AS num_iterations,
                    {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
                    {self.warm_start} AS warm_start,
                    {name}::TEXT AS name,
                    {descr}::TEXT AS description,
                    '{self.start_training_time}'::TIMESTAMP AS start_training_time,
                    '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                    '{self.version}'::TEXT AS madlib_version,
                    ARRAY[{num_classes}]::INTEGER[] AS num_classes,
                    {class_values_colnames},
                    dependent_vartype,
                    {norm_const}::{float32_sql_type} AS {normalizing_const_colname},
                    {metrics_iters}::INTEGER[] AS metrics_iters
                FROM {source_summary_table}
            """.format(self=self,
                       validation_table=validation_table,
                       dep_name_list=dep_name_list,
                       ind_name_list=ind_name_list,
                       object_table=object_table,
                       name=name,
                       descr=descr,
                       num_classes=num_classes,
                       class_values_colnames=class_values_colnames,
                       norm_const=norm_const,
                       float32_sql_type=FLOAT32_SQL_TYPE,
                       normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                       metrics_iters=metrics_iters,
                       source_summary_table=source_summary_table)
        plpy.execute(create_query)

    def update_info_table(self, mst, is_train):
        """
        Store the recorded metric and loss values of one model config in
        its row of the model info table.

        :param mst: mapping for one model config; only the mst key column
                    is read from it.
        :param is_train: True to fill the training_* columns from the
                    train_mst_* collections, False to fill the validation_*
                    columns from the valid_mst_* collections.

        If no metrics were recorded for this mst key, the metrics columns
        and metrics_elapsed_time are set to NULL.  Note that
        metrics_elapsed_time is written in both modes.
        """
        mst_key = mst[self.mst_key_col]

        if is_train:
            metric_by_key = self.train_mst_metric
            eval_time_by_key = self.train_mst_metric_eval_time
            loss_by_key = self.train_mst_loss
            query_template = """
                           UPDATE {self.model_info_tbl} SET
                           training_metrics_final = {metrics_final},
                           training_loss_final = {loss_final},
                           metrics_elapsed_time = {metrics_elapsed_time},
                           training_metrics = {metrics},
                           training_loss = {loss}
                           WHERE {self.mst_key_col} = {mst_key}
                           """
        else:
            metric_by_key = self.valid_mst_metric
            eval_time_by_key = self.valid_mst_metric_eval_time
            loss_by_key = self.valid_mst_loss
            query_template = """
                           UPDATE {self.model_info_tbl} SET
                           validation_metrics_final = {metrics_final},
                           validation_loss_final = {loss_final},
                           metrics_elapsed_time = {metrics_elapsed_time},
                           validation_metrics = {metrics},
                           validation_loss = {loss}
                           WHERE {self.mst_key_col} = {mst_key}
                           """

        if mst_key in metric_by_key:
            metrics_final, metrics = get_metrics_sql_string(
                metric_by_key[mst_key])
            metrics_elapsed_time = "ARRAY{}".format(eval_time_by_key[mst_key])
        else:
            # Nothing recorded for this config: write SQL NULLs.
            metrics = metrics_final = metrics_elapsed_time = "NULL"

        loss_final, loss = get_metrics_sql_string(loss_by_key[mst_key])

        plpy.execute(query_template.format(
            self=self,
            mst_key=mst_key,
            metrics_final=metrics_final,
            metrics=metrics,
            metrics_elapsed_time=metrics_elapsed_time,
            loss_final=loss_final,
            loss=loss))

    def insert_info_table(self):
        """
        Write the training results of every model config in self.msts to
        the model info table, followed by its validation results when a
        validation table is set.
        """
        for mst in self.msts:
            self.update_info_table(mst, True)
            if not self.validation_table:
                # No validation data: only the training columns are filled.
                continue
            self.update_info_table(mst, False)

    def run_training(self, hop, is_very_first_hop):
        """
               This method is called once per hop from the main fit_multiple_model loop.
            The hop param here identifies the hop number within an iteration, starting
            over each iteration at hop 0.  It ranges from 0 to the greater of either
            the number of model configs in the mst table or the number of segments with
            data on them.  This ensures that each model config gets paired with each
            data segment exactly once per iteration.

               If there are more segments than model configs, then there will be some
            NULL mst_key rows in the model_input & model_output tables.  If instead there
            are more mst keys than segments, then the models not being trained this round
            will have "extra" dist keys, meaning dist_key > max_dist_key where max_dist_key
            is the largest dist key in the source table.  Each of these will be distributed
            on some segment, but we don't care which.

            There are 2 main tasks performed in run_training():
                1.)  The actual hop - each of the rows in the model_output table from the
                     previous round are permuted onto the next segment in a round-robin
                     fashion... the result is saved as the model_input table for this round.
                     The bulk of the data in each row is the model weights.  The schedule
                     table is there to guide each of these models from their previous location
                     to their new scheduled location, where they will train this round.

                2.)  Calling fit_transition_multiple_model() - We join the model_input
                     table with the data source table to train the models on the data local
                     to their segment.  The most important concern here is making sure that
                     the plan for this query does not redistribute any of the model weights.
                     The dist keys are carefully chosen so that there should be no data
                     movement--the only time the model weights move is during the actual
                     hop.  Without caching, the models are trained one row at a time,
                     conceptually similar to a UDA.  With caching enabled, all of the
                     rows are combined in memory on the very first round.  So after that
                     we replace the source table with an empty table (cached_source_table),
                     containing only 1 row per segment, with dist keys but no actual data.

            @param hop  Hop number within the current iteration.  For hop 0 the
                        model_output table is simply renamed to model_input;
                        for any later hop it is rebuilt by joining with the
                        schedule table.
            @param is_very_first_hop  True on the very first hop of training.
                        With caching enabled, this is the hop that creates
                        cached_source_table and passes the real data columns
                        to the transition function.

            Raises plpy.SPIError if the training query fails.  If the error
            message contains the marker 'UDF_Detail', the text after the
            marker is moved into the DETAIL field of the error before it is
            re-raised.
        """
        # NOTE: In the DL module, we want to avoid CREATING TEMP tables
        # (creates a slice which stays until the session is disconnected)
        # or minimize writing queries that generate plans with Motions (creating
        # multiple slices on segments).
        # This is mainly to avoid any GPU memory allocation failures. Since GPU
        # memory allocation is tied to the process where it is initialized, failures
        # may occur when a newly created slice(process) tries allocating GPU memory
        # which is already allocated by a previously created slice(process).
        # Therefore we want to have queries that do not add motions and all the
        # sub-queries running Keras/tensorflow operations reuse the same slice(process)
        # that was used for initializing GPU memory.

        DEBUG.start_timing("run_training")
        if hop > 0:
            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
            DEBUG.start_timing("hop")

            # The hop query never changes, so it is prepared once and reused.
            if self.hop_plan is None:
                self.hop_plan = plpy_prepare("""
                    CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
                        SELECT o.{self.mst_key_col},
                               o.{self.model_weights_col},
                               o.{self.model_arch_col},
                               o.{self.compile_params_col},
                               o.{self.fit_params_col},
                               o.{self.object_map_col},
                               s.{self.dist_key_col}
                        FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
                            ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
                        DISTRIBUTED BY ({self.dist_key_col})
                    """.format(self=self)
                )

            plpy_execute(self.hop_plan)

            DEBUG.print_timing("hop")
            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')

            self.truncate_and_drop(self.model_output_tbl)
        else:
            # Skip hop if it's the first in an iteration, just rename
            plpy.execute("""
                ALTER TABLE {self.model_output_tbl}
                    RENAME TO {self.model_input_tbl}
            """.format(self=self))

        #TODO: Fix these to add multi io
        dep_shape_col = self.dep_shape_cols[0]
        ind_shape_col = self.ind_shape_cols[0]
        dep_var_col = self.mb_dep_var_cols[0]
        indep_var_col = self.mb_indep_var_cols[0]
        source_table = self.source_table

        if self.use_caching:
            # Caching populates the independent_var and dependent_var into the cache on the very first hop
            # For the very_first_hop, we want to run the transition function on all segments, including
            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
            # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
            # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
            # table to join for referencing the dist_key
            if is_very_first_hop:
                plpy.execute("""
                    DROP TABLE IF EXISTS {self.cached_source_table};
                    CREATE {self.unlogged_table} TABLE {self.cached_source_table} AS
                        SELECT {self.dist_key_col} FROM {self.source_table}
                            GROUP BY {self.dist_key_col}
                                DISTRIBUTED BY({self.dist_key_col});
                    """.format(self=self))
            else:
                dep_shape_col = 'NULL'
                ind_shape_col = 'NULL'
                dep_var_col = 'NULL'
                indep_var_col = 'NULL'
                source_table = self.cached_source_table

            if is_very_first_hop or self.is_final_training_call:
                if self.num_msts < len(self.dist_keys):
                    # Add some empty rows, so that cache gets
                    #  populated or deleted on all segments, not
                    #  just those with models on them currently.
                    # The dist key column is referenced through
                    #  self.dist_key_col, like in every other query here.
                    insert_empty_rows_query = """
                        INSERT INTO {self.model_input_tbl} ({self.dist_key_col})
                            SELECT {self.dist_key_col} FROM {self.schedule_tbl}
                                WHERE {self.mst_key_col} IS NULL
                    """.format(self=self)
                    plpy_execute(insert_empty_rows_query)

        DEBUG.start_timing("udf")
        if self.udf_plan is None:
            self.udf_plan = plpy_prepare("""
                CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
                SELECT
                    model_in.{self.mst_key_col},
                    CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
                    THEN
                        model_in.{self.model_weights_col}
                    ELSE
                        {self.schema_madlib}.fit_transition_multiple_model(
                            ARRAY[{dep_var_col}]::BYTEA[],
                            ARRAY[{indep_var_col}]::BYTEA[],
                            ARRAY[{dep_shape_col}]::INTEGER[],
                            ARRAY[{ind_shape_col}]::INTEGER[],
                            model_in.{self.model_arch_col}::TEXT,
                            model_in.{self.compile_params_col}::TEXT,
                            model_in.{self.fit_params_col}::TEXT,
                            src.{self.dist_key_col},
                            ARRAY{self.dist_key_mapping},
                            src.{self.gp_segment_id_col},
                            ARRAY{self.segments_per_host},
                            ARRAY{self.images_per_seg_train},
                            ARRAY{self.accessible_gpus_for_seg},
                            model_in.{self.model_weights_col}::BYTEA,
                            $1::BOOLEAN, -- is_final_training_call
                            {self.use_caching}::BOOLEAN,
                            model_in.{self.object_map_col}::BYTEA
                        )
                    END::BYTEA AS {self.model_weights_col},
                    model_in.{self.model_arch_col},
                    model_in.{self.compile_params_col},
                    model_in.{self.fit_params_col},
                    model_in.{self.object_map_col},
                    model_in.{self.dist_key_col}
                FROM {self.model_input_tbl} model_in
                    LEFT JOIN {source_table} src
                    USING ({self.dist_key_col})
                DISTRIBUTED BY({self.dist_key_col})
                """.format(dep_var_col=dep_var_col,
                           indep_var_col=indep_var_col,
                           dep_shape_col=dep_shape_col,
                           ind_shape_col=ind_shape_col,
                           source_table=source_table,
                           self=self
                           ),
                [ 'BOOLEAN' ]
            )

        try:
            plpy_execute(self.udf_plan, [ self.is_final_training_call ] )
        except plpy.SPIError as e:
            msg = e.message
            if 'UDF_Detail' not in msg:
                raise e
            # Split on the first marker only: if the marker text shows up
            #  again further down, a plain split() would return more than
            #  two parts and the unpacking would fail with a ValueError
            #  that hides the original error.
            e.message, detail = msg.split('UDF_Detail', 1)
            # Extract Traceback from segment, add to
            #  DETAIL of error message on coordinator
            e.args = (e.message,)
            spidata = list(e.spidata)
            spidata[1] = detail
            e.spidata = tuple(spidata)
            raise e

        DEBUG.print_timing("udf")

        # Rows whose weights came back NULL are removed from the output table.
        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE {self.model_weights_col} IS NULL".format(self=self))


        self.truncate_and_drop(self.model_input_tbl)

        if self.use_caching and is_very_first_hop:
            # Throw away plan for source_table, force generation of a new one
            #  next time for cached_source_table
            self.udf_plan = None

        DEBUG.print_timing("run_training")

    def truncate_and_drop(self, table):
        """
        Truncate and then drop one of the intermediate tables of an
        iteration (model_input_tbl, model_output_tbl, schedule_tbl).

        Both statements run with the
        `dev_opt_unsafe_truncate_in_subtransaction` guc set to "on" (if
        available), so that TRUNCATE releases the disk space immediately
        and the model_output table does not grow in size with each hop.

        Without this guc the disk space is only released when the whole
        query completes, so every TRUNCATE or DROP keeps adding to the
        disk space in the meantime.

        The guc can cause data loss if not used properly: once TRUNCATE
        has cleared the disk space there is no way to recover the previous
        state of the table.  It must therefore only be set for intermediate
        tables, never for tables created outside the scope of the
        fit_multiple udf.

        :param table: name of the intermediate table to remove.
        """
        cleanup_sql = """
                TRUNCATE TABLE {table};
                DROP TABLE {table}
            """.format(table=table)

        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
            plpy.execute(cleanup_sql)
