# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import plpy

from model_arch_info import *
from madlib_keras_helper import *
from madlib_keras_validator import *
from predict_input_params import PredictParamsProcessor
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import unique_string
from utilities.utilities import get_psql_type
from utilities.utilities import split_quoted_delimited_str
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import quote_ident

from madlib_keras_wrapper import *

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *

class BasePredict():
    """Shared driver behind madlib_keras_predict and madlib_keras_predict_byom.

    __init__ records the user arguments, resolves the per-segment GPU
    layout and normalizes pred_type. Before call_internal_keras() runs, the
    subclass must also set class_values (e.g. via
    set_default_class_values()), dependent_varname, dependent_vartype,
    dependent_var_count, model_arch, model_weights, normalizing_const and
    id_col_type, all of which call_internal_keras() reads.
    """

    # Python types rendered as quoted SQL string literals (and never
    # compared numerically as pred_type). type(u'') is `unicode` on
    # Python 2 and `str` on Python 3.
    _STRING_TYPES = (str, type(u''))

    def __init__(self, schema_madlib, table_to_validate, test_table, id_col,
                 independent_varname, output_table, pred_type, use_gpus, module_name):
        self.schema_madlib = schema_madlib
        self.table_to_validate = table_to_validate
        self.test_table = test_table
        self.id_col = id_col
        # Split the user-supplied column list into individual column names.
        self.independent_varname = split_quoted_delimited_str(independent_varname)
        self.output_table = output_table
        self.pred_type = pred_type
        self.module_name = module_name

        self.use_gpus = use_gpus if use_gpus else False
        self.segments_per_host = get_data_distribution_per_segment(test_table)
        if self.use_gpus:
            accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib,
                                                                  self.segments_per_host,
                                                                  self.module_name)
            # All entries must agree; differing values are rejected as an
            # asymmetric GPU configuration.
            _assert(len(set(accessible_gpus_for_seg)) == 1,
                '{0}: Asymmetric gpu configurations are not supported'.format(self.module_name))
            self.gpus_per_host = accessible_gpus_for_seg[0]
        else:
            self.gpus_per_host = 0

        self._set_default_pred_type()

    def _set_default_pred_type(self):
        """Normalize self.pred_type and derive the output-filter flags.

        - None defaults to 'prob' (keep every class; get_all=True).
        - 'response' becomes 1 (keep only rank 1; is_response=True).
        - A number < 1 sets use_ratio (keep classes with prob > pred_type).
        - Anything else filters by rank <= pred_type.
        Invalid strings are left for InputValidator.validate_pred_type.
        """
        self.pred_type = 'prob' if self.pred_type is None else self.pred_type
        self.is_response = self.pred_type == 'response'
        self.pred_type = 1 if self.is_response else self.pred_type
        self.get_all = self.pred_type == 'prob'
        # Only compare numerically when pred_type is not a string: on
        # Python 3, 'prob' < 1 raises TypeError.
        self.use_ratio = (not isinstance(self.pred_type, BasePredict._STRING_TYPES)
                          and self.pred_type < 1)

    @staticmethod
    def _sql_text_literal(value):
        """Render one Python value as a SQL expression of type TEXT.

        None becomes a real SQL NULL; strings become escaped E'...'
        literals so quotes and backslashes survive intact; anything else
        (int, float, bool, Decimal, ...) is cast to TEXT by the database,
        so e.g. True renders as 'true' and -1 as '-1'.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, BasePredict._STRING_TYPES):
            # E'' strings treat backslash as an escape regardless of the
            # standard_conforming_strings setting, so double both
            # backslashes and single quotes.
            escaped = value.replace('\\', '\\\\').replace("'", "''")
            return "E'" + escaped + "'"
        # repr() keeps full float precision on Python 2 (str() rounds);
        # str() avoids the trailing 'L' repr() puts on Python 2 longs and
        # the "Decimal('...')" wrapper on Decimals.
        literal = repr(value) if isinstance(value, float) else str(value)
        # Parentheses keep negative numbers intact under the :: cast.
        return '(' + literal + ')::TEXT'

    @staticmethod
    def _sql_text_array(values):
        """Render a Python sequence as a SQL ARRAY[...]::TEXT[] expression."""
        return 'ARRAY[' + ', '.join(
            BasePredict._sql_text_literal(v) for v in values) + ']::TEXT[]'

    def call_internal_keras(self):
        """Create self.output_table and fill it with per-class predictions.

        For every test row, internal_keras_predict() returns one
        probability per class value, concatenated across dependent
        variables. Those probabilities are unnested in lock-step with the
        (class_name, class_value) pairs built below, so both follow the
        same dependent-variable order. Rows are then ranked per
        (id, class_name) by probability and filtered according to
        pred_type. The rank column is dropped for pred_type='response'.
        """
        pred_col_name = 'prob'
        pred_col_type = 'double precision'

        class_values = strip_trailing_nulls_from_class_values(self.class_values)
        gp_segment_id_col, seg_ids_test, \
        images_per_seg_test = get_image_count_per_seg_for_non_minibatched_data_from_db(
            self.test_table)

        # Not used below; still set as an attribute in case other code reads it.
        self.pred_vartype = [i.strip('[]') for i in self.dependent_vartype]

        # One (class_name, class_value) pair per class of every dependent
        # variable, in dependent-variable order.
        full_class_name_list = []
        full_class_value_list = []
        for i in range(self.dependent_var_count):
            for class_value in class_values[i]:
                full_class_name_list.append(self.dependent_varname[i])
                full_class_value_list.append(class_value)

        # Build explicit SQL literals instead of relying on Python's list
        # repr: NULLs stay NULL, quotes/backslashes are escaped, and values
        # of different types across dependent variables all become TEXT.
        unnest_sql = """unnest({0}) AS class_name,
                        unnest({1}) AS class_value
                        """.format(self._sql_text_array(full_class_name_list),
                                   self._sql_text_array(full_class_value_list))

        if self.get_all:
            filter_sql = ""
        elif self.use_ratio:
            # pred_type < 1: keep classes whose probability exceeds it.
            filter_sql = "WHERE {pred_col_name} > {self.pred_type}".format(**locals())
        else:
            # pred_type >= 1 ('response' was mapped to 1): keep top ranks.
            filter_sql = "WHERE rank <= {self.pred_type}".format(**locals())

        select_segmentid_comma = ""
        group_by_clause = ""
        join_cond_on_segmentid = ""
        if not is_platform_pg():
            # Compute MIN(ctid) per gp_segment_id and join on both, so each
            # segment gets its own "first row".
            select_segmentid_comma = "{self.test_table}.gp_segment_id AS gp_segment_id,".format(self=self)
            group_by_clause = "GROUP BY {self.test_table}.gp_segment_id".format(self=self)
            join_cond_on_segmentid = "{self.test_table}.gp_segment_id=min_ctid.gp_segment_id AND".format(self=self)

        # Calling CREATE TABLE instead of CTAS, to ensure that the plan_cache_mode
        # guc codepath is called when passing in the weights
        sql = """
            CREATE TABLE {self.output_table}
                ({self.id_col} {self.id_col_type},
                 class_name TEXT,
                 class_value TEXT,
                 {pred_col_name} {pred_col_type},
                 rank INTEGER)
            """.format(**locals())
        plpy.execute(sql)

        independent_varname_sql = ["{0}::REAL[]".format(quote_ident(i)) for i in self.independent_varname]

        # The SQL-level predict function takes exactly five input slots;
        # unused ones are passed as NULL.
        while len(independent_varname_sql) < 5:
            independent_varname_sql.append("NULL::REAL[]")
        independent_varname_sql = ', '.join(independent_varname_sql)

        # Passing huge model weights to internal_keras_predict() for each row
        # resulted in slowness of overall madlib_keras_predict().
        # To avoid this, a CASE is added to pass the model weights only for
        # the very first row(min(ctid)) that is fetched on each segment and NULL
        # for the other rows.

        rank_sql = """ row_number() OVER (PARTITION BY {self.id_col}, class_name
                       ORDER BY {pred_col_name} DESC) AS rank
                       """.format(**locals())
        sql1 = """
            INSERT INTO {self.output_table}
            SELECT *
            FROM (
                SELECT *, {rank_sql}
                FROM (
                    SELECT  {self.id_col}::{self.id_col_type},
                            {unnest_sql},
                            unnest(
                            {self.schema_madlib}.internal_keras_predict
                                ({independent_varname_sql},
                                $1,
                                CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
                                {self.normalizing_const},
                                {gp_segment_id_col},
                                ARRAY{seg_ids_test},
                                ARRAY{images_per_seg_test},
                                {self.gpus_per_host},
                                ARRAY{self.segments_per_host})) AS prob

                            FROM {self.test_table}
                            LEFT JOIN
                                (SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
                                 FROM {self.test_table}
                                 {group_by_clause}) min_ctid
                            ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
                ) __subq1__
            ) __subq2__
            {filter_sql}
            """.format(**locals())
        # $1 = model architecture (text), $2 = model weights (bytea).
        predict_query = plpy.prepare(sql1, ["text", "bytea"])
        plpy.execute(predict_query, [self.model_arch, self.model_weights])

        if self.is_response:
            # Drop the rank column since it is irrelevant
            plpy.execute("""
                ALTER TABLE {self.output_table}
                DROP COLUMN rank
                """.format(**locals()))

    def set_default_class_values(self, in_class_values, dependent_var_count):
        """Populate self.class_values, one entry per dependent variable.

        An entry that is None or [None] is replaced by the integer labels
        0..num_classes-1, with num_classes taken from
        get_num_classes(self.model_arch, ...). Other entries are kept as-is.
        self.model_arch must already be set.
        """
        self.class_values = []
        num_classes = get_num_classes(self.model_arch, dependent_var_count)
        for counter, i in enumerate(in_class_values):
            if (i is None) or (i == [None]):
                # list() so Python 3 also yields a real (mutable) list.
                self.class_values.append(list(range(0, num_classes[counter])))
            else:
                self.class_values.append(i)


@MinWarning("warning")
class Predict(BasePredict):
    """Front end for madlib_keras_predict on a model table produced by fit.

    Passing mst_key selects one model out of a multi-model table (one that
    has an mst_key column). In that case the summary used for validation
    is the model_summary_table exposed by PredictParamsProcessor, which is
    dropped as a view once prediction finishes.
    """
    def __init__(self, schema_madlib, model_table,
                 test_table, id_col, independent_varname,
                 output_table, pred_type, use_gpus,
                 mst_key, **kwargs):
        self.module_name = 'madlib_keras_predict'
        self.model_table = model_table
        self.mst_key = mst_key
        self.is_mult_model = mst_key is not None
        if model_table:
            self.model_summary_table = add_postfix(model_table, "_summary")

        BasePredict.__init__(self, schema_madlib, model_table, test_table,
                             id_col, independent_varname,
                             output_table, pred_type,
                             use_gpus, self.module_name)
        self._load_model_params()

        self.validate()
        self.id_col_type = get_expr_type(self.id_col, self.test_table)
        BasePredict.call_internal_keras(self)

        if self.is_mult_model:
            plpy.execute("DROP VIEW IF EXISTS {0}".format(self.temp_summary_view))

    def _load_model_params(self):
        """Read model, dependent-variable and normalization settings."""
        params = PredictParamsProcessor(self.model_table, self.module_name,
                                        self.mst_key)
        if self.is_mult_model:
            # Validate against the summary relation the processor exposes
            # for the selected mst_key (presumably a temporary view, given
            # the DROP VIEW at the end of __init__).
            self.temp_summary_view = params.model_summary_table
            self.model_summary_table = self.temp_summary_view
        self.dependent_vartype = params.get_dependent_vartype()
        self.model_weights = params.get_model_weights()
        self.model_arch = params.get_model_arch()

        self.dependent_varname = params.get_dependent_varname()
        self.dependent_var_count = len(self.dependent_varname)
        per_dependent_class_values = [params.get_class_values(dep_name)
                                      for dep_name in self.dependent_varname]
        self.set_default_class_values(per_dependent_class_values,
                                      self.dependent_var_count)
        self.normalizing_const = params.get_normalizing_const()

    def validate(self):
        """Check the model/test/output tables, mst_key usage, pred_type and input shape."""
        input_tbl_valid(self.model_table, self.module_name)
        # mst_key must be passed exactly when the model table has an
        # mst_key column.
        has_mst_key_col = columns_exist_in_table(self.model_table, ['mst_key'])
        if self.is_mult_model and not has_mst_key_col:
            plpy.error("{0}: Single model should not pass mst_key".format(self.module_name))
        if not self.is_mult_model and has_mst_key_col:
            plpy.error("{0}: Multi-model needs to pass mst_key".format(self.module_name))

        InputValidator.validate_predict_evaluate_tables(
            self.module_name, self.model_table, self.model_summary_table,
            self.test_table, self.output_table)
        InputValidator.validate_id_in_test_tbl(
            self.module_name, self.test_table, self.id_col)

        expected_shape = get_input_shape(self.model_arch)
        InputValidator.validate_pred_type(
            self.module_name, self.pred_type, self.class_values)
        InputValidator.validate_input_shape(
            self.test_table, self.independent_varname, expected_shape, 1)

@MinWarning("warning")
class PredictBYOM(BasePredict):
    """Front end for madlib_keras_predict_byom (bring your own model).

    The architecture and weights are read from model_arch_table for
    model_id. Class values and the normalizing constant come from the
    caller instead of a fit summary table.
    """
    def __init__(self, schema_madlib, model_arch_table, model_id,
                 test_table, id_col, independent_varname, output_table,
                 pred_type, use_gpus, class_values, normalizing_const,
                 dependent_count, **kwargs):
        self.module_name = 'madlib_keras_predict_byom'
        self.model_arch_table = model_arch_table
        self.model_id = model_id
        self.class_values = class_values
        self.dependent_var_count = dependent_count
        # Fall back to the library default when no constant is supplied.
        if normalizing_const is None:
            self.normalizing_const = DEFAULT_NORMALIZING_CONST
        else:
            self.normalizing_const = normalizing_const

        # No summary table names the outputs, so synthesize names:
        # 'dependent_var' for one output, else 'dependent_var_0', ...
        if dependent_count == 1:
            self.dependent_varname = ['dependent_var']
        else:
            self.dependent_varname = ['dependent_var_{0}'.format(idx)
                                      for idx in range(dependent_count)]

        BasePredict.__init__(self, schema_madlib, model_arch_table,
                             test_table, id_col, independent_varname,
                             output_table, pred_type, use_gpus, self.module_name)

        # pred_type is already normalized here ('response' -> 1).
        if self.class_values:
            self.dependent_vartype = [get_psql_type(values[0])
                                      for values in self.class_values]
        else:
            self.class_values = [None] * self.dependent_var_count
            fallback_type = 'text' if self.pred_type == 1 else 'double precision'
            self.dependent_vartype = [fallback_type] * self.dependent_var_count

        InputValidator.validate_predict_byom_tables(
            self.module_name, self.model_arch_table, self.model_id,
            self.test_table, self.id_col, self.output_table,
            self.independent_varname)
        self.validate_and_set_defaults()
        self.id_col_type = get_expr_type(self.id_col, self.test_table)
        BasePredict.call_internal_keras(self)

    def validate_and_set_defaults(self):
        """Load the model, fill default class values, then validate inputs.

        The architecture must be loaded first because both the default
        class values and the expected input shape are derived from it.
        """
        arch, weights = get_model_arch_weights(self.model_arch_table,
                                               self.model_id)
        self.model_arch, self.model_weights = arch, weights
        _assert(self.model_weights and self.model_arch,
                "{0}: Model weights and architecture should not be NULL.".format(
                    self.module_name))
        self.set_default_class_values(self.class_values, self.dependent_var_count)

        InputValidator.validate_pred_type(
            self.module_name, self.pred_type, self.class_values)
        InputValidator.validate_normalizing_const(
            self.module_name, self.normalizing_const)

        # TODO: InputValidator.validate_class_values() is disabled. It keys
        # off the 'units' attribute, which every layer has, so it only
        # passed because layers are traversed in order; it cannot handle
        # multi-io models and is fragile otherwise. Re-enable once fixed.
        expected_shape = get_input_shape(self.model_arch)
        InputValidator.validate_input_shape(
            self.test_table, self.independent_varname, expected_shape, 1)

def internal_keras_predict_wide(independent_var, independent_var2,
                                independent_var3, independent_var4, independent_var5,
                                model_architecture, model_weights,
                                normalizing_const, current_seg_id, seg_ids,
                                images_per_seg, gpus_per_host, segments_per_host,
                                **kwargs):
    """Fixed-arity adapter for internal_keras_predict.

    Takes up to five independent-variable arrays as separate arguments
    (unused slots are expected to be None), bundles them into one list and
    forwards every other argument unchanged.
    """
    all_inputs = [independent_var, independent_var2, independent_var3,
                  independent_var4, independent_var5]
    return internal_keras_predict(all_inputs, model_architecture,
                                  model_weights, normalizing_const,
                                  current_seg_id, seg_ids, images_per_seg,
                                  gpus_per_host, segments_per_host, **kwargs)

def internal_keras_predict(independent_var, model_architecture, model_weights,
                           normalizing_const, current_seg_id, seg_ids,
                           images_per_seg, gpus_per_host, segments_per_host,
                           **kwargs):
    SD = kwargs['SD']
    model_key = 'segment_model_predict'
    row_count_key = 'row_count'
    try:
        device_name = get_device_name_and_set_cuda_env(gpus_per_host, current_seg_id)
        if model_key not in SD:
            set_keras_session(device_name, gpus_per_host, segments_per_host[current_seg_id])
            model = model_from_json(model_architecture)
            set_model_weights(model, model_weights)
            SD[model_key] = model
            SD[row_count_key] = 0
        else:
            model = SD[model_key]
        SD[row_count_key] += 1

        # Since the test data isn't mini-batched,
        # we have to make sure that the test data np array has the same
        # number of dimensions as input_shape. So we add a dimension to x.

        independent_var_filtered = []
        for i in independent_var:
            if i is not None:
                independent_var_filtered.append(expand_input_dims(i)/normalizing_const)
        with tf.device(device_name):
            probs = model.predict(independent_var_filtered)
        # probs is a list containing a list of probability values, of all
        # class levels. Since we are assuming each input is a single image,
        # and not mini-batched, this list contains exactly one list in it,
        # so return back the first list in probs.
        result = []
        if len(independent_var_filtered) > 1:
            for i in probs:
                for j in i[0]:
                    result.append(j)
        else:
            result = probs[0]
        total_images = get_image_count_per_seg_from_array(seg_ids.index(current_seg_id),
                                                          images_per_seg)

        if SD[row_count_key] == total_images:
            SD.pop(model_key, None)
            SD.pop(row_count_key, None)
            clear_keras_session()
        return result
    except Exception as ex:
        SD.pop(model_key, None)
        SD.pop(row_count_key, None)
        clear_keras_session()
        plpy.error(ex)

def predict_help(schema_madlib, message, **kwargs):
    """
    Help function for keras predict

    Args:
        @param schema_madlib: MADlib schema name, substituted into the text
        @param message: string, Help message string. Empty/None gives the
                        summary; 'usage', 'help' or '?' gives the usage.
        @param kwargs

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------
This function allows the user to predict using a madlib_keras_fit trained
model.

For more details on function usage:
    SELECT {schema_madlib}.madlib_keras_predict('usage')
            """
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
 SELECT {schema_madlib}.madlib_keras_predict(
    model_table,    --  Name of the table containing the model
    test_table,     --  Name of the table containing the evaluation dataset
    id_col,         --  Name of the id column in the test data table
    independent_varname,    --  Name of the column with independent
                                variables in the test table
    output_table,   --  Name of the output table
    pred_type,      --  The type of the desired output: 'response' (top
                        class only), 'prob' (all classes, the default),
                        a value < 1 (keep classes whose probability is
                        greater than it) or a value >= 1 (keep classes
                        whose rank is at most that value)
    use_gpus,       --  Flag for enabling GPU support
    mst_key         --  Identifier for the desired model out of multimodel
                        training output
 );

-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) contains the following columns:

id_col:                 The id of each test_table row being predicted
                        (the column has the same name as id_col).
class_name:             Name of the dependent variable the prediction is for.
class_value:            A class value (label) of that dependent variable.
prob:                   The probability of the given class value.
rank:                   The rank of the class value by probability, within
                        each id and class_name. Not present when
                        pred_type is 'response'.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.madlib_keras_predict()"

    return help_string.format(schema_madlib=schema_madlib)

def predict_byom_help(schema_madlib, message, **kwargs):
    """
    Help function for keras predict_byom

    Args:
        @param schema_madlib: MADlib schema name, substituted into the text
        @param message: string, Help message string. Empty/None gives the
                        summary; 'usage', 'help' or '?' gives the usage.
        @param kwargs

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------
This function allows the user to predict with their own pre-trained model
(note that this model doesn't have to be trained using MADlib).

For more details on function usage:
    SELECT {schema_madlib}.madlib_keras_predict_byom('usage')
            """
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
 SELECT {schema_madlib}.madlib_keras_predict_byom(
    model_arch_table,       --  Name of the table containing the model architecture
                                and the pre-trained model weights
    model_id,               --  This is the id in 'model_arch_table' containing the
                                model architecture
    test_table,             --  Name of the table containing the evaluation dataset
    id_col,                 --  Name of the id column in the test data table
    independent_varname,    --  Name of the column with independent
                                variables in the test table
    output_table,           --  Name of the output table
    pred_type,              --  The type of the desired output: 'response'
                                (top class only), 'prob' (all classes, the
                                default), a value < 1 (keep classes whose
                                probability is greater than it) or a value
                                >= 1 (keep classes whose rank is at most
                                that value)
    use_gpus,               --  Flag for enabling GPU support
    class_values,           --  List of class labels that were used while training the
                                model. If NULL, the class values default to
                                0, 1, ..., (number of classes - 1), with the
                                number of classes read from the model
                                architecture.
    normalizing_const       --  Normalizing constant used for standardizing arrays in
                                independent_varname. If NULL, a default is used.
 );

-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) contains the following columns:

id_col:                 The id of each test_table row being predicted
                        (the column has the same name as id_col).
class_name:             Name of the dependent variable the prediction is for:
                        'dependent_var' for a single-output model, otherwise
                        'dependent_var_0', 'dependent_var_1', ...
class_value:            A class value (label) of that dependent variable.
prob:                   The probability of the given class value.
rank:                   The rank of the class value by probability, within
                        each id and class_name. Not present when
                        pred_type is 'response'.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.madlib_keras_predict_byom()"

    return help_string.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------
