# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import dill
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import get_col_name_type_sql_string
from utilities.utilities import current_user
from utilities.utilities import is_superuser
from utilities.utilities import get_schema
from utilities.validate_args import columns_missing_from_table
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import quote_ident
from utilities.validate_args import unquote_ident
from utilities.validate_args import table_exists

# Prefix used in user-facing messages emitted by this module.
module_name = 'Keras Custom Function'


class CustomFunctionSchema:
    """Expected format of a custom function table.

    col_names and col_types are parallel tuples: col_names[i] is created
    with type col_types[i], so their orders must be kept in sync.

    Example uses:

        from utilities.validate_args import columns_missing_from_table
        from madlib_keras_custom_function import CustomFunctionSchema

        # Validate names in cols list against actual table
        missing_cols = columns_missing_from_table(
            'my_custom_fn_table', CustomFunctionSchema.col_names)

        # Get function object from table, without hard coding column names
        sql = ("SELECT {object} FROM {table} WHERE {id} = {my_id}"
               .format(object=CustomFunctionSchema.FN_OBJ,
                       table='my_custom_fn_table',
                       id=CustomFunctionSchema.FN_ID,
                       my_id=1))
        # plpy.execute returns rows; pick the object column of the first one
        fn_obj = plpy.execute(sql)[0][CustomFunctionSchema.FN_OBJ]
    """
    FN_ID = 'id'
    FN_NAME = 'name'
    FN_OBJ = 'object'
    FN_DESC = 'description'
    # Table columns in creation order, with their SQL types (same order).
    col_names = (FN_ID, FN_NAME, FN_DESC, FN_OBJ)
    col_types = ('SERIAL', 'TEXT', 'TEXT', 'BYTEA')

def _validate_object(object, **kwargs):
    """Ensure `object` is a non-NULL payload that dill can unpickle.

    WARNING: dill.loads() can execute arbitrary code embedded in the
    payload. Only call this on data supplied by trusted (admin) users.

    Raises (via _assert / plpy.error) if `object` is None or if unpickling
    fails; the unpickling error text is included in the message.
    """
    _assert(object is not None,
            "{0}: function object cannot be NULL!".format(module_name))
    try:
        dill.loads(object)
    except Exception as e:
        plpy.error("{0}: Invalid function object: {1}".format(module_name, e))

@MinWarning("error")
def load_custom_function(schema_madlib, object_table, object, name, description=None, **kwargs):
    """Store a dill-pickled function object in a custom function table.

    The table <schema_madlib>.<object_table> is created (with SELECT
    granted to PUBLIC) if it does not exist; otherwise its columns are
    checked against CustomFunctionSchema. The function is then inserted as
    a new row whose id is generated by the table default.

    Args:
        schema_madlib: MADlib schema name; the table is always created in it.
        object_table: Unqualified name of the custom function table.
        object: BYTEA payload containing the dill-pickled function.
        name: Function name; unique within the table (primary key).
        description: Optional free-text description.

    Raises (via _assert / plpy.error): if the caller is not a superuser,
    object_table / object / name is NULL, the payload cannot be unpickled,
    an existing table is missing columns, or `name` already exists.
    """
    # Check privileges first: _validate_object() unpickles the payload with
    # dill, which can execute code embedded in it, so a non-admin caller
    # must be rejected before that happens.
    _assert(is_superuser(current_user()),
            "DL: The user has to have admin privileges to load a custom "
            "function")
    _assert(object_table is not None,
            "{0}: object table name cannot be NULL!".format(module_name))
    _validate_object(object)
    _assert(name is not None,
            "{0}: function name cannot be NULL!".format(module_name))

    object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
    try:
        if not table_exists(object_table):
            col_defs = get_col_name_type_sql_string(CustomFunctionSchema.col_names,
                                                    CustomFunctionSchema.col_types)

            sql = """CREATE TABLE {object_table}
                                  ({col_defs}, PRIMARY KEY({fn_name}))
                """.format(object_table=object_table,
                           col_defs=col_defs,
                           fn_name=CustomFunctionSchema.FN_NAME)
            plpy.execute(sql, 0)
            # Using plpy.notice here as this function can be called:
            # 1. Directly by the user, we do want to display to the user
            #    if we create a new table or later the function name that
            #    is added to the table
            # 2. From load_top_k_accuracy_function, since plpy.info
            #    displays the query context when called from the function
            #    there is a very verbose output and cannot be suppressed with
            #    MinWarning decorator as INFO is always displayed irrespective
            #    of what the decorator sets the client_min_messages to.
            #    Therefore, instead we print this information as a NOTICE
            #    when called directly by the user and suppress it by setting
            #    MinWarning decorator to 'error' level in the calling function.
            plpy.notice("{0}: Created new custom function table {1}."
                        .format(module_name, object_table))
            plpy.execute("GRANT SELECT ON {0} TO PUBLIC".format(object_table))
        else:
            missing_cols = columns_missing_from_table(object_table,
                                                      CustomFunctionSchema.col_names)
            if missing_cols:
                plpy.error("{0}: Invalid custom function table {1},"
                           " missing columns: {2}".format(module_name,
                                                          object_table,
                                                          missing_cols))

        # Name the target columns explicitly so the insert does not depend
        # on the physical column order of a pre-existing table.
        insert_cols = (CustomFunctionSchema.FN_NAME,
                       CustomFunctionSchema.FN_DESC,
                       CustomFunctionSchema.FN_OBJ)
        col_type = dict(zip(CustomFunctionSchema.col_names,
                            CustomFunctionSchema.col_types))
        insert_query = plpy.prepare(
            "INSERT INTO {0} ({1}) VALUES($1, $2, $3);".format(
                object_table, ", ".join(insert_cols)),
            [col_type[col] for col in insert_cols])

        plpy.execute(insert_query, [name, description, object], 0)
    # spiexceptions.UniqueViolation is only supported for PG>=9.2. For
    # GP5(based of PG8.4) it cannot be used. Therefore, checking exception
    # message for duplicate key error. str(e) is used instead of e.message,
    # which only exists on Python 2 exceptions.
    except Exception as e:
        if 'duplicate key' in str(e):
            plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
        plpy.error(e)

    plpy.notice("{0}: Added function {1} to {2} table".
                format(module_name, name, object_table))

@MinWarning("error")
def delete_custom_function(schema_madlib, object_table, id=None, name=None, **kwargs):
    """Delete one function, by id or by name, from a custom function table.

    If `id` is given it takes precedence over `name`. When the table has no
    rows left after the delete, it is dropped.

    Args:
        schema_madlib: MADlib schema name; the table must live in it.
        object_table: Table name. An unqualified name is qualified with
            schema_madlib; a qualified name must use schema_madlib.
        id: Value of the id column of the row to delete.
        name: Value of the name column of the row to delete (used only
            when `id` is None).

    Raises (via plpy.error / _assert / input_tbl_valid): if the table is in
    another schema, is invalid or missing columns, both id and name are
    NULL, or no matching row exists.
    """
    if object_table is not None:
        schema_name = get_schema(object_table)
        if schema_name is None:
            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
        elif schema_name != schema_madlib:
            plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))

    input_tbl_valid(object_table, module_name)
    _assert(id is not None or name is not None,
            "{0}: function id/name cannot be NULL! "
            "Use \"SELECT delete_custom_function('usage')\" for help.".format(module_name))

    missing_cols = columns_missing_from_table(object_table, CustomFunctionSchema.col_names)
    if missing_cols:
        plpy.error("{0}: Invalid custom function table {1},"
                   " missing columns: {2}".format(module_name, object_table,
                                                  missing_cols))

    if id is not None:
        key_col, key_type, key_value = CustomFunctionSchema.FN_ID, 'INTEGER', id
        key_desc = "id {0}".format(id)
    else:
        key_col, key_type, key_value = CustomFunctionSchema.FN_NAME, 'TEXT', name
        key_desc = "name {0}".format(name)

    # Bind the key as a parameter instead of splicing it into the SQL text,
    # so names containing quotes or "$$" cannot break or inject into it.
    delete_plan = plpy.prepare(
        "DELETE FROM {0} WHERE {1} = $1".format(object_table, key_col),
        [key_type])
    res = plpy.execute(delete_plan, [key_value], 0)

    if res.nrows() > 0:
        plpy.notice("{0}: Object {1} has been deleted from {2}.".
                    format(module_name, key_desc, object_table))
    else:
        plpy.error("{0}: Object {1} not found".format(module_name, key_desc))

    # Only need to know whether any row is left, not fetch them all.
    res = plpy.execute("SELECT 1 FROM {0} LIMIT 1".format(object_table), 0)
    if not res:
        plpy.notice("{0}: Dropping empty custom keras function "
                    "table {1}".format(module_name, object_table))
        plpy.execute("DROP TABLE {0}".format(object_table), 0)

def update_builtin_metrics(builtin_metrics):
    """Append extra built-in metric names to `builtin_metrics`.

    The names 'accuracy', 'acc', 'crossentropy' and 'ce' are appended in
    that order, in place. The same (mutated) object is returned. These are
    presumably metric aliases accepted without a custom function; confirm
    this against the callers.
    """
    for metric_alias in ('accuracy', 'acc', 'crossentropy', 'ce'):
        builtin_metrics.append(metric_alias)
    return builtin_metrics

@MinWarning("error")
def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
    """Load a top-k categorical accuracy function into a custom function table.

    The function is generated by <schema_madlib>.top_k_categorical_acc_pickled
    and stored through <schema_madlib>.load_custom_function. It is stored
    under the name 'top_<k>_accuracy' with the description
    'returns top_<k>_accuracy'.

    Args:
        schema_madlib: MADlib schema name.
        object_table: Name of the custom function table.
        k: Positive integer k of the top-k accuracy metric.

    Raises (via _assert, or errors from the nested SQL calls): if
    object_table is NULL, k is NULL or not positive, or loading fails.
    """
    _assert(object_table is not None,
            "{0}: object table name cannot be NULL!".format(module_name))
    _assert(k is not None and k > 0,
            "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))

    object_table = quote_ident(object_table)
    fn_name = "top_{0}_accuracy".format(k)

    # User-supplied values are bound as parameters rather than spliced into
    # single-quoted literals, so a quote in object_table cannot break or
    # inject into the statement.
    plan = plpy.prepare(
        """
        SELECT {schema_madlib}.load_custom_function(
                   $1,
                   {schema_madlib}.top_k_categorical_acc_pickled($2, $3),
                   $3,
                   $4)
        """.format(schema_madlib=schema_madlib),
        ["VARCHAR", "INTEGER", "TEXT", "TEXT"])
    plpy.execute(plan, [object_table, k, fn_name, "returns {0}".format(fn_name)])
    # As this function allocates the name for the top_k_accuracy function,
    # printing it out here so the user doesn't need to lookup for the
    # newly added custom function name in the object_table
    plpy.info("{0}: Added function '{1}' to '{2}' table".
              format(module_name, fn_name, object_table))

class KerasCustomFunctionDocumentation:
    """Online help text for the Keras custom function SQL entry points.

    Every *_help method returns its summary for an empty message. It
    returns its usage text for 'usage', 'help' or '?' (case-insensitive),
    and a "No such option" hint for anything else.
    """

    @staticmethod
    def _returnHelpMsg(schema_madlib, message, summary, usage, method):
        """Select the help text to return for `message`."""
        if not message:
            return summary
        elif message.lower() in ('usage', 'help', '?'):
            return usage
        return """
            No such option. Use "SELECT {schema_madlib}.{method}()"
            for help.
        """.format(schema_madlib=schema_madlib, method=method)

    @staticmethod
    def load_custom_function_help(schema_madlib, message):
        """Return help text for load_custom_function."""
        method = "load_custom_function"
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        The user can specify custom functions as part of the parameters
        passed to madlib_keras_fit()/madlib_keras_fit_multiple(). These
        custom function(s) definition must be stored in a table to pass.
        This is a helper function to help users insert object(BYTEA) of
        the function definitions into a table.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(schema_madlib=schema_madlib, method=method)

        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table,       --  VARCHAR. Output table to load custom function.
            object,             --  BYTEA. dill pickled object of the function definition.
            name,               --  TEXT. Free text string to identify a name
            description         --  TEXT. Free text string to provide a description
        );


        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_custom_function contains the following columns:

        'id'                    -- SERIAL. Function ID.
        'name'                  -- TEXT PRIMARY KEY. unique function name.
        'description'           -- TEXT. function description.
        'object'                -- BYTEA. dill pickled function object.

        """.format(schema_madlib=schema_madlib, method=method)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
    # ---------------------------------------------------------------------

    @staticmethod
    def delete_custom_function_help(schema_madlib, message):
        """Return help text for delete_custom_function."""
        method = "delete_custom_function"
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        Delete the custom function corresponding to the provided id or name
        from the custom function repository table (object_table).

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(schema_madlib=schema_madlib, method=method)

        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table     VARCHAR, -- Table containing keras custom function objects.
            id               INTEGER  -- The id of the keras custom function object
                                         to be deleted.
        );

        SELECT {schema_madlib}.{method}(
            object_table     VARCHAR, -- Table containing keras custom function objects.
            name             TEXT     -- Function name of the keras custom function
                                         object to be deleted.
        );

        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        This method deletes the row corresponding to the given id or name
        in the object_table. If the table is empty after the delete, it is
        also dropped. If the table cannot be dropped (for example, because
        views depend on it), an error is raised.

        ---------------------------------------------------------------------------
        """.format(schema_madlib=schema_madlib, method=method)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)

    @staticmethod
    def load_top_k_accuracy_function_help(schema_madlib, message):
        """Return help text for load_top_k_accuracy_function."""
        method = "load_top_k_accuracy_function"
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        The user can specify a custom k value for the top k accuracy metric.
        This loads a function named top_<k>_accuracy into object_table.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(schema_madlib=schema_madlib, method=method)

        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table,       --  VARCHAR. Output table to load custom function.
            k                   --  INTEGER. The k in top k accuracy (must be a positive integer).
        );


        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_top_k_accuracy_function contains the following columns:

        'id'                    -- SERIAL. Function ID.
        'name'                  -- TEXT PRIMARY KEY. unique function name.
        'description'           -- TEXT. function description.
        'object'                -- BYTEA. dill pickled function object.

        """.format(schema_madlib=schema_madlib, method=method)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
    # ---------------------------------------------------------------------
