# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
from utilities.utilities import add_postfix
from utilities.utilities import unique_string
from utilities.utilities import is_platform_pg
from utilities.validate_args import table_exists
import plpy


############### Constants used in other deep learning files #########
# Names of the columns in the model summary table.
CLASS_VALUES_COLNAME = "class_values"
NORMALIZING_CONST_COLNAME = "normalizing_const"
COMPILE_PARAMS_COLNAME = "compile_params"
DEPENDENT_VARNAME_COLNAME = "dependent_varname"
DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
INDEPENDENT_VARNAME_COLNAME = "independent_varname"
MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
MODEL_ARCH_ID_COLNAME = "model_arch_id"
MODEL_WEIGHTS_COLNAME = "model_weights"
METRIC_TYPE_COLNAME = "metrics_type"

# Names of the independent, dependent and distribution key columns in the
# batched table.
# These are read-only variables; do not modify them.
# MADLIB-1300 Adding these variables for DL only at this time.
MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL = "dependent_var"
MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL = "independent_var"
DISTRIBUTION_KEY_COLNAME = "__dist_key__"
# SQL type names.
FLOAT32_SQL_TYPE = 'REAL'
SMALLINT_SQL_TYPE = 'SMALLINT'

# NOTE(review): presumably the normalizing constant used when the caller does
# not supply one -- confirm at the call sites in the other DL files.
DEFAULT_NORMALIZING_CONST = 1.0

#####################################################################

def expand_input_dims(input_data):
    """
    Convert input_data to a float32 numpy array and prepend one dimension
    of length one to it using expand_dims.
    Example: an input of shape (3,) is returned with shape (1, 3).
    @args:
        @param: input_data, anything np.array can convert to float32
    @returns:
        float32 numpy array with an extra leading axis
    """
    as_float32 = np.array(input_data, dtype=np.float32)
    return np.expand_dims(as_float32, axis=0)

def np_array_float32(var, var_shape):
    """
    Read the buffer var as float32 values and return them as a numpy array
    whose shape is set to var_shape.
    @args:
        @param: var, buffer object holding float32 data
        @param: var_shape, shape to assign to the resulting array
    @returns:
        float32 numpy array of shape var_shape
    """
    float32_values = np.frombuffer(var, dtype=np.float32)
    # The shape is assigned on the array itself rather than via a reshaped copy.
    float32_values.shape = var_shape
    return float32_values

def np_array_int16(var, var_shape):
    """
    Read the buffer var as int16 values and return them as a numpy array
    whose shape is set to var_shape.
    @args:
        @param: var, buffer object holding int16 data
        @param: var_shape, shape to assign to the resulting array
    @returns:
        int16 numpy array of shape var_shape
    """
    int16_values = np.frombuffer(var, dtype=np.int16)
    # The shape is assigned on the array itself rather than via a reshaped copy.
    int16_values.shape = var_shape
    return int16_values

def strip_trailing_nulls_from_class_values(class_values):
    """
        class_values is a list of unique class levels in training data and
        may carry several Nones. A None in the very first position is kept;
        the list is cut off at the first None found at any later position,
        so that None and everything after it is dropped.
        Examples:
            1) input class_values = ['cat', 'dog']
               output class_values = ['cat', 'dog']

            2) input class_values = [None, 'cat', 'dog']
               output class_values = [None, 'cat', 'dog']

            3) input class_values = [None, 'cat', 'dog', None, None]
               output class_values = [None, 'cat', 'dog']

            4) input class_values = ['cat', 'dog', None, None]
               output class_values = ['cat', 'dog']

            5) input class_values = [None, None]
               output class_values = [None]
        @args:
            @param: class_values, list (may be None)
        @returns:
            updated class_values list, or None if class_values is None
    """
    if class_values is None:
        return class_values

    # Position of the first None that is not the leading element. It stays
    # None when there is no such element, which makes the slice below keep
    # everything.
    cutoff = None
    for position, level in enumerate(class_values):
        if position > 0 and level is None:
            cutoff = position
            break
    # Pass only the valid class_values for creating columns
    return class_values[:cutoff]

def get_image_count_per_seg_from_array(current_seg_id, seg_ids, images_per_seg):
    """
    Look up the image count of the current segment in images_per_seg, the
    array holding the image counts of all segments.
    On postgres the first entry is used; otherwise the entry sitting at the
    position of current_seg_id within seg_ids is used.
    This function is only called from inside the transition function.
    :param current_seg_id: id of the segment the caller is running on
    :param seg_ids: array of segment ids, parallel to images_per_seg
    :param images_per_seg: array of image counts, one per entry of seg_ids
    :return: image count for the current segment
    """
    # seg_ids is not consulted at all on postgres.
    seg_position = 0 if is_platform_pg() else seg_ids.index(current_seg_id)
    return images_per_seg[seg_position]

def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
    """
    Query the given minibatch formatted table and return the total number of
    images per segment.
    Since we cannot pass a dictionary to the keras fit step function, the
    result is returned as two parallel arrays instead.
    This function assumes that the table is not empty and is minibatched, which
    means that it would have been distributed by __dist_key__.
    :param table_name: name of the minibatched table to query
    :return: Returns two arrays
    1. An array containing the segment keys. On postgres this is [0];
    otherwise it holds the __dist_key__ values in the order the query returns
    them (the query has no ORDER BY).
    2. An array containing the total images for each entry of the first
    array.
    """
    shape_col = add_postfix(MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")

    if is_platform_pg():
        shape_rows = plpy.execute(
            """ SELECT {0}::SMALLINT[] AS shape
                FROM {1}
            """.format(shape_col, table_name))
        # The number of images in a buffer is the first dimension in its shape;
        # everything is summed into a single count reported under key 0.
        total_images = sum(row['shape'][0] for row in shape_rows)
        return [0], [total_images]

    # The number of images in the buffer is the first dimension in the shape.
    # Using __dist_key__ instead of gp_segment_id: Since gp_segment_id is
    # not the actual distribution key column, the optimizer/planner
    # generates a plan with Redistribute Motion, creating multiple slices on
    # each segment. For DL, since GPU memory allocation is tied to the process
    # where it is initialized, we want to minimize creating any additional
    # slices per segment. This is mainly to avoid any GPU memory allocation
    # failures which can occur when a newly created slice(process) tries
    # allocating GPU memory which is already allocated by a previously
    # created slice(process).
    # Since the minibatch_preprocessor evenly distributes the data with
    # __dist_key__ as the input table's distribution key, using this for
    # calculating total images on each segment will avoid creating
    # unnecessary slices(processes).
    per_key_rows = plpy.execute(
        """ SELECT {0}, sum({1}[1]) AS images_per_seg
                FROM {2}
                GROUP BY {0}
            """.format(DISTRIBUTION_KEY_COLNAME, shape_col, table_name))
    seg_ids = [int(row[DISTRIBUTION_KEY_COLNAME]) for row in per_key_rows]
    images_per_seg = [int(row["images_per_seg"]) for row in per_key_rows]
    return seg_ids, images_per_seg

def get_image_count_per_seg_for_non_minibatched_data_from_db(table_name):
    """
    Query the given non minibatch formatted table and return the total rows
    per segment.
    Since we cannot pass a dictionary to the keras fit step function, the
    result is returned as parallel arrays instead.
    This function assumes that the table is not empty.
    :param table_name: name of the table to query
    :return: gp segment id col name and two arrays
    1. The gp segment id col name: the literal '0' on postgres, otherwise
    '<table_name>.gp_segment_id'.
    2. An array containing the segment numbers. On postgres this is [0];
    otherwise it holds the gp_segment_id values in the order the query
    returns them (the query has no ORDER BY).
    3. An array containing the total rows for each entry of the segment
    array.
    """
    if is_platform_pg():
        # Single row: the row count of the whole table, reported under key 0.
        count_rows = plpy.execute(
            """ SELECT count(*) AS images_per_seg
                FROM {0}
            """.format(table_name))
        total_rows = [int(row["images_per_seg"]) for row in count_rows]
        return '0', [0], total_rows

    # Compute total buffers on each segment
    count_rows = plpy.execute(
        """ SELECT gp_segment_id, count(*) AS images_per_seg
                FROM {0}
                GROUP BY gp_segment_id
            """.format(table_name))
    seg_ids = [int(row["gp_segment_id"]) for row in count_rows]
    images_per_seg = [int(row["images_per_seg"]) for row in count_rows]
    gp_segment_id_col = '{0}.gp_segment_id'.format(table_name)
    return gp_segment_id_col, seg_ids, images_per_seg

def parse_shape(shape):
    """
    Parse the shape format given by the sql into an int array, e.g.
        [1:10][1:32][1:3] -> [10, 32, 3]
    @args:
        @param: shape, string in the format shown above
    @returns:
        list of ints, one per ':' found in shape
    """
    # Everything before the first ':' is discarded; each remaining piece
    # starts with the number we want, terminated by ']'.
    pieces_after_colon = shape.split(':')[1:]
    return [int(piece.partition(']')[0]) for piece in pieces_after_colon]


def query_model_configs(model_selection_table, model_selection_summary_table,
                        mst_key_col, model_arch_table_col):
    """
    Read the model configurations and the model arch table value.
    :param model_selection_table: table from which all rows are selected
    :param model_selection_summary_table: table from which
        model_arch_table_col is selected
    :param mst_key_col: column the model_selection_table rows are ordered by
    :param model_arch_table_col: column read from the summary table
    :return: a list with every row of model_selection_table ordered by
        mst_key_col, and the model_arch_table_col value of the first row
        returned for model_selection_summary_table
    """
    select_msts_sql = """
                 SELECT * FROM {model_selection_table}
                 ORDER BY {mst_key_col}
                 """.format(model_selection_table=model_selection_table,
                            mst_key_col=mst_key_col)
    select_arch_table_sql = """
                             SELECT {model_arch_table_col}
                             FROM {model_selection_summary_table}
                             """.format(
        model_arch_table_col=model_arch_table_col,
        model_selection_summary_table=model_selection_summary_table)

    msts = list(plpy.execute(select_msts_sql))
    # Only the first row of the summary table result is consulted.
    summary_rows = plpy.execute(select_arch_table_sql)
    return msts, summary_rows[0][model_arch_table_col]

def query_dist_keys(source_table, dist_key_col):
    """
    Read distinct keys from the source table.
    :param source_table: table to query
    :param dist_key_col: column whose distinct values are read
    :return: list of the distinct dist_key_col values, ordered by
        dist_key_col
    """
    distinct_keys_sql = """
                     SELECT DISTINCT({dist_key_col}) FROM {source_table}
                     ORDER BY {dist_key_col}
                     """.format(source_table=source_table,
                                dist_key_col=dist_key_col)
    key_rows = list(plpy.execute(distinct_keys_sql))
    return [row[dist_key_col] for row in key_rows]

def query_weights(model_output_table, model_weights_col, mst_key_col, mst_key):
    """
    Read the weights stored for one mst_key.
    :param model_output_table: table to query
    :param model_weights_col: column holding the weights
    :param mst_key_col: column compared against mst_key
    :param mst_key: key value placed as-is into the WHERE clause
    :return: the model_weights_col value of the first row returned
    """
    select_weights_sql = """
                        SELECT {model_weights_col}, {mst_key_col}
                        FROM {model_output_table}
                        WHERE {mst_key_col} = {mst_key}
                        """.format(model_weights_col=model_weights_col,
                                   mst_key_col=mst_key_col,
                                   model_output_table=model_output_table,
                                   mst_key=mst_key)

    weight_rows = plpy.execute(select_weights_sql)
    # Only the first matching row is consulted.
    first_row = weight_rows[0]
    return first_row[model_weights_col]

def create_summary_view(module_name, model_table, mst_key):
    """
    Create a view over the summary and info tables of model_table for the
    given mst_key and return the name of the view.
    The summary and info table names are obtained from model_table with
    add_postfix ("_summary" and "_info"). plpy.error is called when either
    table does not exist, or when the info table has no row for mst_key.
    :param module_name: name used as the prefix of the error messages
    :param model_table: model table whose summary and info tables are used
    :param mst_key: key value placed as-is into the WHERE clauses
    :return: name of the created view
    """
    tmp_view_summary = unique_string('tmp_view_summary')
    model_summary_table = add_postfix(model_table, "_summary")
    model_info_table = add_postfix(model_table, "_info")

    if not table_exists(model_summary_table) or \
            not table_exists(model_info_table):
        plpy.error("{0}: Missing summary and/or info tables for {1}".format(
            module_name, model_table))

    matching_rows = plpy.execute("""
        SELECT mst_key FROM {model_info_table} WHERE mst_key = {mst_key}
        """.format(model_info_table=model_info_table, mst_key=mst_key))
    if len(matching_rows) == 0:
        plpy.error("{0}: mst_key {1} does not exist in the info table".format(
            module_name, mst_key))

    # Since fit multiple does not have a model arch id, we set its value to -1.
    # Otherwise, the model arch validation will fail.
    # This approach is chosen in case we decide to support model arch id in
    # the future.
    create_view_sql = """
        CREATE VIEW {tmp_view_summary} AS
        SELECT *, -1::SMALLINT AS {model_arch_id_colname}
        FROM {model_summary_table}, {model_info_table}
        WHERE mst_key = {mst_key}
        """.format(tmp_view_summary=tmp_view_summary,
                   model_arch_id_colname=MODEL_ARCH_ID_COLNAME,
                   model_summary_table=model_summary_table,
                   model_info_table=model_info_table,
                   mst_key=mst_key)
    plpy.execute(create_view_sql)
    return tmp_view_summary
