# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


"""
@file input_data_preprocessor.py_in

"""
from math import ceil
import plpy

from internal.db_utils import get_distinct_col_levels
from internal.db_utils import quote_literal
from internal.db_utils import get_product_of_dimensions
from utilities.minibatch_preprocessing import MiniBatchBufferSizeCalculator
from utilities.control import OptimizerControl
from utilities.control import HashaggControl
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import is_platform_pg
from utilities.utilities import is_psql_char_type
from utilities.utilities import is_valid_psql_type
from utilities.utilities import is_var_valid
from utilities.utilities import BOOLEAN, NUMERIC, ONLY_ARRAY, TEXT
from utilities.utilities import py_list_to_sql_string
from utilities.utilities import split_quoted_delimited_str
from utilities.utilities import strip_end_quotes
from utilities.utilities import unique_string
from utilities.utilities import validate_module_input_params
from utilities.utilities import get_seg_number
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import get_expr_type

from madlib_keras_helper import *
import time

# Name of the summary-table column that holds the number of classes
# recorded for the dependent variable(s).
NUM_CLASSES_COLNAME = "num_classes"
class DistributionRulesOptions:
    """Keyword values recognised for the distribution_rules argument."""
    # Distribute only to segments located on hosts that report GPUs.
    GPU_SEGMENTS = 'gpu_segments'
    # Distribute to every segment; used when no rule is supplied.
    ALL_SEGMENTS = 'all_segments'

class InputDataPreprocessorDL(object):
    def __init__(self, schema_madlib, source_table, output_table,
                 dependent_varname, independent_varname, buffer_size,
                 normalizing_const, num_classes, distribution_rules, module_name):
        """
            Store the preprocessing arguments, validate them and look up the
            SQL type of every dependent and independent variable.

            dependent_varname and independent_varname are delimited strings
            that get split into lists. distribution_rules is lower-cased; a
            falsy value selects DistributionRulesOptions.ALL_SEGMENTS.
        """
        self.schema_madlib = schema_madlib
        self.module_name = module_name

        # Tables
        self.source_table = source_table
        self.output_table = output_table
        self.output_summary_table = None
        if self.output_table:
            self.output_summary_table = add_postfix(self.output_table, "_summary")

        # Variables: names are split now, types are resolved further below
        self.dependent_varname = split_quoted_delimited_str(dependent_varname)
        self.independent_varname = split_quoted_delimited_str(independent_varname)
        self.dependent_vartype = None
        self.independent_vartype = None

        # Batching / encoding parameters
        self.buffer_size = buffer_size
        self.normalizing_const = normalizing_const
        self.num_classes = num_classes

        # Distribution of the output across segments
        if distribution_rules:
            self.distribution_rules = distribution_rules.lower()
        else:
            self.distribution_rules = DistributionRulesOptions.ALL_SEGMENTS
        self.gpu_config = '$__madlib__${0}$__madlib__$'.format(DistributionRulesOptions.ALL_SEGMENTS)

        ## Validating input args prior to using them in _set_validate_vartypes()
        self._validate_args()
        self._set_validate_vartypes()

        self.dependent_levels = None
        # The number of padded zeros to include in 1-hot vector
        self.padding_size = 0

    def _set_one_hot_encoding_variables(self):
        """
            Set variables such as dependent_levels and padding_size.

            For every dependent variable that has class levels, any None
            level (a NULL class value in SQL) is replaced by the string
            'NULL'. When num_classes is given, padding_size becomes a list
            holding, per dependent variable, the number of extra entries
            needed to reach num_classes.

            padding_size is indexed by dependent-variable position by its
            consumers, so a 0 placeholder is stored for dependent variables
            that have no class levels to keep the indexes aligned.
        """
        if not self.dependent_levels:
            return

        self.padding_size = []
        validated = False
        for i, tmp_levels in enumerate(self.dependent_levels):
            if not tmp_levels:
                # Nothing to encode for this variable; keep padding_size
                # aligned with the dependent-variable index.
                if self.num_classes:
                    self.padding_size.append(0)
                continue

            # if any class level was NULL in sql, that would show up as
            # None in self.dependent_levels. Replace all None with NULL
            # in the list.
            self.dependent_levels[i] = ['NULL' if level is None else level
                                        for level in tmp_levels]

            # The validation only looks at the number of levels, which the
            # NULL replacement does not change, so running it once suffices.
            if not validated:
                self._validate_num_classes()
                validated = True

            # Compute padding_size only after the validation has passed.
            if self.num_classes:
                self.padding_size.append(
                    self.num_classes[i] - len(self.dependent_levels[i]))

    def _validate_num_classes(self):
        """
            Check that every entry of num_classes is at least as large as the
            number of distinct class levels of the matching dependent
            variable, and report an error through plpy.error otherwise.

            Nothing is checked when num_classes is None. Dependent variables
            whose levels entry is None are skipped, since there are no levels
            to compare against (other methods of this class treat such
            entries as "no levels" as well).
        """
        if self.num_classes is None:
            return

        for i, requested_classes in enumerate(self.num_classes):
            levels = self.dependent_levels[i]
            if levels is None:
                # len(None) would raise a TypeError; there is nothing to
                # validate for a variable without class levels.
                continue
            if requested_classes < len(levels):
                plpy.error("{0}: Invalid num_classes value specified. It must "\
                    "be equal to or greater than distinct class values found "\
                    "in table ({1}).".format(
                        self.module_name, len(levels)))

    def _validate_distribution_table(self):
        """
            Validate the distribution rules table named by
            self.distribution_rules.

            The checks, in order:
              1. the table exists (input_tbl_valid),
              2. it has a 'dbid' column,
              3. it contains at least one row,
              4. it contains no duplicate dbids,
              5. every dbid belongs to a primary segment listed in
                 gp_segment_configuration (content >= 0 AND role = 'p').
            Failures are reported through input_tbl_valid / _assert.
        """
        input_tbl_valid(self.distribution_rules, self.module_name,
                        error_suffix_str="""
                        segments_to_use table ({self.distribution_rules}) doesn't exist.
                        """.format(self=self))
        _assert(is_var_valid(self.distribution_rules, 'dbid'),
                "{self.module_name}: distribution rules table must contain dbid column".format(
                    self=self))
        dbids = plpy.execute("""
            SELECT array_agg(dbid) AS dbids FROM gp_segment_configuration
            WHERE content >= 0 AND role = 'p'
            """)[0]['dbids']
        dist_result = plpy.execute("""
            SELECT array_agg(dbid) AS dbids,
                   count(dbid) AS c1,
                   count(DISTINCT dbid) AS c2
            FROM {0} """.format(self.distribution_rules))

        # array_agg() yields NULL (None) when the table has no rows, which
        # would otherwise surface as a TypeError in the loop below.
        rule_dbids = dist_result[0]['dbids']
        _assert(rule_dbids,
            '{self.module_name}: distribution rules table is empty'.format(
                self=self))

        _assert(dist_result[0]['c1'] == dist_result[0]['c2'],
            '{self.module_name}: distribution rules table contains duplicate dbids'.format(
                self=self))

        # Set membership keeps the per-dbid check cheap on large clusters.
        valid_dbids = set(dbids or [])
        for i in rule_dbids:
            _assert(i in valid_dbids,
                '{self.module_name}: invalid dbid:{i} in the distribution rules table'.format(
                    self=self, i=i))

    def get_one_hot_encoded_dep_var_expr(self):
        """
        Build the SELECT-list fragment that transforms every dependent
        variable, based on self.dependent_varname, self.dependent_vartype,
        self.dependent_levels and self.padding_size.

        :return:
            A single string: the per-variable expressions joined by ', '.

            If dep_type == numeric[] , do not encode, only cast:
                    dependent_varname = rings
                        transformed_value = rings::<smallint type>[]
            otherwise encode against the known class levels:
                    dependent_varname = rings (encoding)
                        transformed_value = ARRAY[(rings) IS NOT DISTINCT FROM 1,
                                                  (rings) IS NOT DISTINCT FROM 2,
                                                  ...] cast to a smallint array,
                        followed by one 'false' per padded class.
        """
        select_items = []
        for idx, dep_type in enumerate(self.dependent_vartype):
            dep_name = self.dependent_varname[idx]
            dep_levels = self.dependent_levels[idx]

            # Assuming the input NUMERIC[] is already one_hot_encoded,
            # so it is only cast.
            if is_valid_psql_type(dep_type, NUMERIC | ONLY_ARRAY):
                select_items.append(
                    "{0}::{1}[]".format(dep_name, SMALLINT_SQL_TYPE))
                continue

            # For DL use case, we want to allow NULL as a valid class value,
            # so the query must have 'IS NOT DISTINCT FROM' instead of '='
            # like in the generic get_one_hot_encoded_expr() defined in
            # db_utils.py_in. We also have this optional 'num_classes' param
            # that affects the logic of 1-hot encoding. Since this is very
            # specific to input_preprocessor_dl for now, it is kept here
            # instead of being refactored into a generic helper function.
            comparisons = []
            for level in dep_levels:
                comparisons.append(
                    "({0}) IS NOT DISTINCT FROM {1}".format(dep_name, level))
            if self.padding_size:
                comparisons.extend(['false'] * self.padding_size[idx])

            # In psql, we can't directly convert boolean to smallint, so we
            # first convert it to integer and then cast to smallint
            encoded_array = ', '.join(comparisons)
            select_items.append('ARRAY[{0}]::INTEGER[]::{1}[] AS {2}'.format(
                encoded_array, SMALLINT_SQL_TYPE, dep_name))

        return ', '.join(select_items)

    def _get_var_shape(self, varname):
        """
            Return parse_shape() applied to the array_dims() of varname,
            read from a single row of the source table.
        """
        dims_query = "SELECT array_dims({0}) AS shape FROM {1} LIMIT 1".format(
            varname, self.source_table)
        first_row = plpy.execute(dims_query)[0]
        return parse_shape(first_row['shape'])

    def _get_independent_var_shape(self):
        """
            Return a list with the shape of every independent variable, in
            the order of self.independent_varname.
        """
        return [self._get_var_shape(ind_var)
                for ind_var in self.independent_varname]

    def _get_dependent_var_shape(self):
        """
            Return a flat list describing the shape of the dependent
            variables, visited in the order of self.dependent_varname:
              - num_classes[i] when num_classes is given,
              - otherwise the number of class levels when the variable has
                levels,
              - otherwise the dimensions reported by _get_var_shape().
        """
        dims = []
        for idx, dep_name in enumerate(self.dependent_varname):
            if self.num_classes:
                dims.append(self.num_classes[idx])
                continue

            levels = self.dependent_levels[idx]
            if levels:
                dims.append(len(levels))
            else:
                dims = dims + self._get_var_shape(dep_name)
        return dims


    def input_preprocessor_dl(self, order_by_random=True):
        """
            Creates the output and summary table that does the following
            pre-processing operations on the input data:
            1) Normalizes the independent variable.
            2) Minibatches the normalized independent variable.
            3) One-hot encodes the dependent variable.
            4) Minibatches the one-hot encoded dependent variable.

            :param order_by_random: when True (default) an
                " ORDER BY RANDOM()" clause is added so rows are shuffled
                before batching; when False the clause is omitted.

            On postgres three queries are run (encode/normalize, batching,
            bytea conversion). On greenplum the data is additionally mapped
            to the segments selected by self.distribution_rules
            ('all_segments', 'gpu_segments' or the name of a table of dbids).

            Side effects: creates self.output_table and (through
            _create_output_summary_table) the summary table, and overwrites
            self.buffer_size, self.normalizing_const (when falsy or
            approximately 1), self.distribution_rules and self.gpu_config
            with the values that get recorded in the summary table.

            NOTE: most queries are formatted with locals(), so local variable
            names in this method are part of the query templates.
        """
        # setup for 1-hot encoding
        self._set_one_hot_encoding_variables()

        # Generate random strings for TEMP tables
        series_tbl = unique_string(desp='series')
        dist_key_tbl = unique_string(desp='dist_key')
        normalized_tbl = unique_string(desp='normalized_table')
        batched_table = unique_string(desp='batched_table')

        # Used later in locals() for formatting queries
        x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
        y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
        float32=FLOAT32_SQL_TYPE
        dep_shape_col = add_postfix(y, "_shape")
        ind_shape_col = add_postfix(x, "_shape")

        ind_shape = self._get_independent_var_shape()
        ind_shape = [','.join([str(i) for i in tmp_shape]) for tmp_shape in ind_shape]
        dep_shape = self._get_dependent_var_shape()
        dep_shape = [str(i) for i in dep_shape]
        one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()

        # skip normalization step if normalizing_const = 1.0
        rescale_independent_var = []
        if self.normalizing_const and (self.normalizing_const < 0.999999 or self.normalizing_const > 1.000001):

            for i in self.independent_varname:

                rescale_independent_var.append("""{self.schema_madlib}.array_scalar_mult(
                                                  {i}::{float32}[],
                                                  (1/{self.normalizing_const})::{float32})
                                                  AS {i}_norm
                                               """.format(**locals()))
        else:
            self.normalizing_const = DEFAULT_NORMALIZING_CONST
            for i in self.independent_varname:
                rescale_independent_var.append("{i}::{float32}[] AS {i}_norm".format(**locals()))
        rescale_independent_var = ', '.join(rescale_independent_var)


        # It's important that we shuffle all rows before batching for fit(), but
        #  we can skip that for predict()
        order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""

        concat_sql = []
        shape_sql = []
        bytea_sql = []

        for i,j in zip(self.independent_varname, ind_shape):
            concat_sql.append("""
                {self.schema_madlib}.agg_array_concat(ARRAY[{i}_norm::{float32}[]]) AS {i}
                """.format(**locals()))
            shape_sql.append("""
                ARRAY[count, {j}]::INTEGER[] AS {i}_shape
                """.format(**locals()))
            bytea_sql.append("""
                {self.schema_madlib}.array_to_bytea({i}) AS {i}
                """.format(**locals()))

        for i,j in zip(self.dependent_varname, dep_shape):
            concat_sql.append("""
                {self.schema_madlib}.agg_array_concat(ARRAY[{i}]) AS {i}
                """.format(**locals()))
            shape_sql.append("""
                ARRAY[count, {j}]::INTEGER[] AS {i}_shape
                """.format(**locals()))
            bytea_sql.append("""
                {self.schema_madlib}.array_to_bytea({i}) AS {i}
                """.format(**locals()))

        concat_sql = ', '.join(concat_sql)
        shape_sql = ', '.join(shape_sql)
        bytea_sql = ', '.join(bytea_sql)

        # This query template will be used later in pg & gp specific code paths,
        #  where {make_buffer_id} and {dist_by_buffer_id} are filled in
        batching_query = """
            CREATE TEMP TABLE {batched_table} AS SELECT
                {{make_buffer_id}} buffer_id,
                {concat_sql},
                COUNT(*) AS count
            FROM {normalized_tbl}
            GROUP BY buffer_id
            {{dist_by_buffer_id}}
        """.format(**locals())

        # This query template will be used later in pg & gp specific code paths,
        #  where {dist_key_col_comma} and {dist_by_dist_key} will be filled in
        bytea_query = """
            CREATE TABLE {self.output_table} AS SELECT
                {{dist_key_col_comma}}
                {bytea_sql},
                {shape_sql},
                buffer_id
            FROM {batched_table}
            {{dist_by_dist_key}}
        """.format(**locals())

        if is_platform_pg():
            # used later for writing summary table
            self.distribution_rules = '$__madlib__${0}$__madlib__$'.format(DistributionRulesOptions.ALL_SEGMENTS)

            #
            # For postgres, we just need 3 simple queries:
            #   1-hot-encode/normalize + batching + bytea conversion
            #

            # see note in gpdb code branch (lower down) on
            # 1-hot-encoding of dependent var
            one_hot_sql = """
                CREATE TEMP TABLE {normalized_tbl} AS SELECT
                    (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER as row_id,
                    {rescale_independent_var},
                    {one_hot_dep_var_array_expr}
                FROM {self.source_table}
            """.format(**locals())
            plpy.execute(one_hot_sql)

            self.buffer_size = self._get_buffer_size(1)

            # Used to format query templates with locals()
            make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)

            dist_by_dist_key = ''
            dist_by_buffer_id = ''
            dist_key_col_comma = ''

            # Disable hashagg since large number of arrays being concatenated
            # could result in excessive memory usage.
            with HashaggControl(False):
                # Batch rows with GROUP BY
                plpy.execute(batching_query.format(**locals()))

            plpy.execute("DROP TABLE {0}".format(normalized_tbl))

            # Convert to BYTEA and output final (permanent table)
            plpy.execute(bytea_query.format(**locals()))

            plpy.execute("DROP TABLE {0}".format(batched_table))

            self._create_output_summary_table()

            return

        # Done with postgres, rest is all for gpdb
        #
        # This gpdb code path is far more complex, and depends on
        #   how the user wishes to distribute the data.  Even if
        #   it's to be spread evenly across all segments, we still
        #   need to do some extra work to ensure that happens.

        if self.distribution_rules == DistributionRulesOptions.ALL_SEGMENTS:
            all_segments = True
            self.distribution_rules = '$__madlib__${0}$__madlib__$'.format(DistributionRulesOptions.ALL_SEGMENTS)
            num_segments = get_seg_number()
        else:
            all_segments = False

        if self.distribution_rules == DistributionRulesOptions.GPU_SEGMENTS:
            #TODO can we reuse the function `get_accessible_gpus_for_seg` from
            # madlib_keras_helper
            gpu_info_table = unique_string(desp='gpu_info')
            plpy.execute("""
                SELECT {self.schema_madlib}.gpu_configuration('{gpu_info_table}')
            """.format(**locals()))
            gpu_query = """
                SELECT array_agg(DISTINCT(hostname)) as gpu_config
                FROM {gpu_info_table}
            """.format(**locals())
            gpu_query_result = plpy.execute(gpu_query)[0]['gpu_config']
            if not gpu_query_result:
                plpy.error("{self.module_name}: No GPUs configured on hosts.".format(self=self))
            plpy.execute("DROP TABLE IF EXISTS {0}".format(gpu_info_table))

            gpu_config_hostnames = "ARRAY{0}".format(gpu_query_result)
            # find hosts with gpus
            get_segment_query = """
                SELECT array_agg(content) as segment_ids,
                       array_agg(dbid) as dbid,
                       count(*) as count
                FROM gp_segment_configuration
                WHERE content != -1 AND role = 'p'
                AND hostname=ANY({gpu_config_hostnames})
            """.format(**locals())
            segment_ids_result = plpy.execute(get_segment_query)[0]

            self.gpu_config = "ARRAY{0}".format(sorted(segment_ids_result['segment_ids']))
            self.distribution_rules = "ARRAY{0}".format(sorted(segment_ids_result['dbid']))

            num_segments = segment_ids_result['count']

        elif not all_segments:  # Read from a table with dbids to distribute the data
            self._validate_distribution_table()
            gpu_query = """
                SELECT array_agg(content) as gpu_config,
                       array_agg(gp_segment_configuration.dbid) as dbid
                FROM {self.distribution_rules} JOIN gp_segment_configuration
                ON {self.distribution_rules}.dbid = gp_segment_configuration.dbid
            """.format(**locals())
            gpu_query_result = plpy.execute(gpu_query)[0]
            self.gpu_config = "ARRAY{0}".format(sorted(gpu_query_result['gpu_config']))
            num_segments = plpy.execute("SELECT count(*) as count FROM {self.distribution_rules}".format(**locals()))[0]['count']
            self.distribution_rules = "ARRAY{0}".format(sorted(gpu_query_result['dbid']))

        join_key = 't.buffer_id % {num_segments}'.format(**locals())

        if not all_segments:
            # Map the buffer's slot number onto one of the selected segment
            # ids (SQL arrays are 1-based, hence the + 1).
            join_key = '({self.gpu_config})[{join_key} + 1]'.format(**locals())

        # Create large temp table such that there is at least 1 row on each segment
        # Using 999999 would distribute data (at least 1 row on each segment) for
        # a cluster as large as 20000
        dist_key_col = DISTRIBUTION_KEY_COLNAME
        query = """
            CREATE TEMP TABLE {series_tbl} AS
                SELECT generate_series(0, 999999) {dist_key_col}
                DISTRIBUTED BY ({dist_key_col})
            """.format(**locals())

        plpy.execute(query)

        # Used in locals() to format queries, including template queries
        #  bytea_query & batching_query defined in section common to
        #  pg & gp (very beginning of this function)
        dist_by_dist_key = 'DISTRIBUTED BY ({dist_key_col})'.format(**locals())
        dist_by_buffer_id = 'DISTRIBUTED BY (buffer_id)'
        dist_key_col_comma = dist_key_col + ' ,'
        make_buffer_id = ''

        dist_key_query = """
                CREATE TEMP TABLE {dist_key_tbl} AS
                SELECT min({dist_key_col}) AS {dist_key_col}
                FROM {series_tbl}
                GROUP BY gp_segment_id
                DISTRIBUTED BY ({dist_key_col})
        """.format(**locals())

        plpy.execute(dist_key_query)

        plpy.execute("DROP TABLE {0}".format(series_tbl))

        # Always one-hot encode the dependent var. For now, we are assuming
        # that input_preprocessor_dl will be used only for deep
        # learning and mostly for classification. So make a strong
        # assumption that it is only for classification, so one-hot
        # encode the dep var, unless it's already a numeric array in
        # which case we assume it's already one-hot encoded.

        # While 1-hot-encoding is done, we also normalize the independent
        # var and randomly shuffle the rows on each segment.  (The dist key
        # we're adding avoids any rows moving between segments.  This may
        # make things slightly less random, but helps with speed--probably
        # a safe tradeoff to make.)

        norm_tbl = unique_string(desp='norm_table')

        one_hot_sql = """
            CREATE TEMP TABLE {norm_tbl} AS
            SELECT {dist_key_col},
                {rescale_independent_var},
                {one_hot_dep_var_array_expr}
            FROM {self.source_table} s JOIN {dist_key_tbl} AS d
                ON (s.gp_segment_id = d.gp_segment_id)
            {order_by_clause}
            DISTRIBUTED BY ({dist_key_col})
        """.format(**locals())
        plpy.execute(one_hot_sql)

        rows_per_seg_tbl = unique_string(desp='rows_per_seg')
        start_rows_tbl = unique_string(desp='start_rows')

        #  Generate rows_per_segment table; this small table will
        #  just have one row on each segment containing the number
        #  of rows on that segment in the norm_tbl
        sql = """
            CREATE TEMP TABLE {rows_per_seg_tbl} AS SELECT
                COUNT(*) as rows_per_seg,
                {dist_key_col}
            FROM {norm_tbl}
            GROUP BY {dist_key_col}
            DISTRIBUTED BY ({dist_key_col})
        """.format(**locals())

        plpy.execute(sql)

        #  Generate start_rows_tbl from rows_per_segment table.
        #  This assigns a start_row number for each segment based on
        #  the sum of all rows in previous segments.  These will be
        #  added to the row numbers within each segment to get an
        #  absolute index into the table.  All of this is to accomplish
        #  the equivalent of ROW_NUMBER() OVER() on the whole table,
        #  but this way is much faster because we don't have to do an
        #  N:1 Gather Motion (moving entire table to a single segment
        #  and scanning through it).
        #
        sql = """
            CREATE TEMP TABLE {start_rows_tbl} AS SELECT
                {dist_key_col},
                SUM(rows_per_seg) OVER (ORDER BY gp_segment_id) - rows_per_seg AS start_row
            FROM {rows_per_seg_tbl}
            DISTRIBUTED BY ({dist_key_col})
        """.format(**locals())

        plpy.execute(sql)

        plpy.execute("DROP TABLE {0}".format(rows_per_seg_tbl))

        self.buffer_size = self._get_buffer_size(num_segments)

        # The query below assigns slot_id's to each row within
        #  a segment, computes a row_id by adding start_row for
        #  that segment to it, then divides by buffer_size to make
        #  this into a buffer_id
        # ie:
        #  buffer_id = row_id / buffer_size
        #     row_id = start_row + slot_id
        #    slot_id = ROW_NUMBER() OVER(PARTITION BY <dist key>)::INTEGER
        #
        #   Instead of partitioning by gp_segment_id itself, we
        # use __dist_key__ col instead.  This is the same partition,
        # since there's a 1-to-1 mapping between the columns; but
        # using __dist_key__ avoids an extra Redistribute Motion.
        #
        # Note: even though the ordering of these two columns is
        #  different, this doesn't matter as each segment is being
        #  numbered separately (only the start_row is different,
        #  and those are fixed to the correct segments by the JOIN
        #  condition).

        ind_norm_comma_list = ', '.join(["{0}_norm".format(i) for i in self.independent_varname])
        dep_norm_comma_list = ', '.join(self.dependent_varname)
        sql = """
        CREATE TEMP TABLE {normalized_tbl} AS SELECT
            {dist_key_col},
            {ind_norm_comma_list},
            {dep_norm_comma_list},
            (ROW_NUMBER() OVER( PARTITION BY {dist_key_col} ))::INTEGER as slot_id,
            ((start_row +
               (ROW_NUMBER() OVER( PARTITION BY {dist_key_col} ) - 1)
             )::INTEGER / {self.buffer_size}
            ) AS buffer_id
        FROM {norm_tbl} JOIN {start_rows_tbl}
            USING ({dist_key_col})
        ORDER BY buffer_id
        DISTRIBUTED BY (slot_id)
        """.format(**locals())

        plpy.execute(sql)   # label buffer_id's

        # A note on DISTRIBUTED BY (slot_id) in above query:
        #
        #     In the next query, we'll be doing the actual batching.  Due
        #  to the GROUP BY, gpdb will Redistribute on buffer_id.  We could
        #  avoid this by using DISTRIBUTED BY (buffer_id) in the above
        #  (buffer-labelling) query.  But this also causes the GROUP BY
        #  to use single-stage GroupAgg instead of multistage GroupAgg,
        #  which for unknown reasons is *much* slower and often runs out
        #  of VMEM unless it's set very high!

        plpy.execute("DROP TABLE {norm_tbl}, {start_rows_tbl}".format(**locals()))

        # Disable optimizer (ORCA) for platforms that use it
        # since we want to use a groupagg instead of hashagg
        with OptimizerControl(False):
            with HashaggControl(False):
                # Run actual batching query
                plpy.execute(batching_query.format(**locals()))

        plpy.execute("DROP TABLE {0}".format(normalized_tbl))

        if not all_segments: # remove any segments we don't plan to use
            sql = """
                DELETE FROM {dist_key_tbl}
                    WHERE NOT gp_segment_id = ANY({self.gpu_config})
            """.format(**locals())
            # The statement used to be built but never run; execute it so
            # the dist key table only keeps the segments that will be used.
            plpy.execute(sql)

        plpy.execute("ANALYZE {dist_key_tbl}".format(**locals()))
        plpy.execute("ANALYZE {batched_table}".format(**locals()))

        # Redistribute from buffer_id to dist_key
        #
        #  This has to be separate from the batching query, because
        #   we found that adding DISTRIBUTED BY (dist_key) to that
        #   query causes it to run out of VMEM on large datasets such
        #   as places100.  Possibly this is because the memory available
        #   for GroupAgg has to be shared with an extra slice if they
        #   are part of the same query.
        #
        #  We also tried adding this to the BYTEA conversion query, but
        #   that resulted in slower performance than just keeping it
        #   separate.
        #
        sql = """CREATE TEMP TABLE {batched_table}_dist_key AS
                    SELECT {dist_key_col}, t.*
                        FROM {batched_table} t
                            JOIN {dist_key_tbl} d
                                ON {join_key} = d.gp_segment_id
                            DISTRIBUTED BY ({dist_key_col})
              """.format(**locals())

        # match buffer_id's with dist_keys
        plpy.execute(sql)

        sql = """DROP TABLE {batched_table}, {dist_key_tbl};
                 ALTER TABLE {batched_table}_dist_key RENAME TO {batched_table}
              """.format(**locals())
        plpy.execute(sql)

        # Convert batched table to BYTEA and output as final (permanent) table
        plpy.execute(bytea_query.format(**locals()))

        plpy.execute("DROP TABLE {0}".format(batched_table))

        # Create summary table
        self._create_output_summary_table()

    def _create_output_summary_table(self):
        """
            Create self.output_summary_table with a single row describing
            the preprocessing run: source/output table names, variable names
            and types, class values, buffer size, normalizing constant,
            number of classes, distribution rules and gpu configuration.

            Side effects: when num_classes is given, the padding 'NULL'
            entries are appended in place to self.dependent_levels; and
            self.num_classes is replaced by the SQL array expression that is
            written to the table.
        """
        # Used as the class-values column when no dependent variable has
        # any class levels.
        fallback_class_values = 'NULL::{0}[] AS {1}_{2}'.format(
            self.dependent_vartype[0], self.dependent_varname[0],
            CLASS_VALUES_COLNAME)
        class_value_columns = []
        inferred_counts = []

        for idx, dep_type in enumerate(self.dependent_vartype):
            levels = self.dependent_levels[idx]
            if not levels:
                continue

            if self.num_classes:
                # Update dependent_levels to include NULL when
                # num_classes > len(self.dependent_levels)
                levels.extend(['NULL'] * self.padding_size[idx])
            else:
                inferred_counts.append(str(len(levels)))

            levels_sql = py_list_to_sql_string(
                levels, array_type=dep_type, long_format=True)
            class_value_columns.append("{0} AS {1}_{2}".format(
                levels_sql, self.dependent_varname[idx], CLASS_VALUES_COLNAME))

        if class_value_columns:
            class_level_str = ', '.join(class_value_columns)
        else:
            class_level_str = fallback_class_values

        if self.num_classes is None:
            self.num_classes = "ARRAY[{0}]::INTEGER[]".format(
                ', '.join(inferred_counts))
        else:
            self.num_classes = "ARRAY{0}".format(self.num_classes)

        summary_template = """
            CREATE TABLE {self.output_summary_table} AS
            SELECT
                $__madlib__${self.source_table}$__madlib__$::TEXT AS source_table,
                $__madlib__${self.output_table}$__madlib__$::TEXT AS output_table,
                ARRAY{self.dependent_varname} AS {dependent_varname_colname},
                ARRAY{self.independent_varname} AS {independent_varname_colname},
                ARRAY{self.dependent_vartype} AS {dependent_vartype_colname},
                {class_level_str},
                {self.buffer_size} AS buffer_size,
                {self.normalizing_const}::{FLOAT32_SQL_TYPE} AS {normalizing_const_colname},
                {self.num_classes}::INTEGER[] AS {num_classes_colname},
                {self.distribution_rules} AS {distribution_rules},
                {self.gpu_config} AS {internal_gpu_config}
            """
        summary_query = summary_template.format(
            self=self,
            class_level_str=class_level_str,
            dependent_varname_colname=DEPENDENT_VARNAME_COLNAME,
            independent_varname_colname=INDEPENDENT_VARNAME_COLNAME,
            dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
            normalizing_const_colname=NORMALIZING_CONST_COLNAME,
            num_classes_colname=NUM_CLASSES_COLNAME,
            internal_gpu_config=INTERNAL_GPU_CONFIG,
            distribution_rules=DISTRIBUTION_RULES_COLNAME,
            FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE)
        plpy.execute(summary_query)

    def _validate_args(self):
        """
            Run the generic module input validation, then make sure that
            buffer_size and normalizing_const, when given (not None), are
            greater than zero.
        """
        validate_module_input_params(
            self.source_table, self.output_table, self.independent_varname,
            self.dependent_varname, self.module_name, None,
            [self.output_summary_table])

        # (value, error message template) pairs, checked in this order
        positive_checks = (
            (self.buffer_size,
             "{0}: The buffer size has to be a positive integer or NULL."),
            (self.normalizing_const,
             "{0}: The normalizing constant has to be a positive integer or NULL."),
        )
        for value, message in positive_checks:
            if value is None:
                continue
            _assert(value > 0, message.format(self.module_name))

    def _set_validate_vartypes(self):
        """
            Resolve the type of every independent and dependent variable
            against the source table (via get_expr_type) and store them in
            self.independent_vartype / self.dependent_vartype, in the same
            order as the variable names.
        """
        self.independent_vartype = [
            get_expr_type(ind_var, self.source_table)
            for ind_var in self.independent_varname]

        self.dependent_vartype = [
            get_expr_type(dep_var, self.source_table)
            for dep_var in self.dependent_varname]

    def get_distinct_dependent_levels(self, table, dependent_varname,
                                      dependent_vartype):
        """
            Return the distinct levels of dependent_varname in table.

            Kept in the parent class so that include_nulls is passed in as
            true for both training and validation tables.
        """
        distinct_levels = get_distinct_col_levels(
            table, dependent_varname, dependent_vartype, include_nulls=True)
        return distinct_levels

    def _get_buffer_size(self, num_segments):
        """
            Compute the number of rows to put in each buffer.

            The smallest buffer size suggested by
            MiniBatchBufferSizeCalculator over all independent variables is
            taken (capped by the row count of the source table). The
            resulting number of buffers is rounded up to a multiple of
            num_segments, and the rows are then divided over that many
            buffers.

            :param num_segments: number of segments the buffers are spread
                over (1 is passed on postgres).
            :return: buffer size as an int.
        """
        count_query = """
                SELECT count(*) AS cnt FROM {0}
            """.format(self.source_table)
        total_rows = plpy.execute(count_query)[0]['cnt']
        calculator = MiniBatchBufferSizeCalculator()

        rows_per_buffer = total_rows
        for ind_var in self.independent_varname:
            ind_var_dim = get_product_of_dimensions(self.source_table, ind_var)
            candidate = calculator.calculate_default_buffer_size(
                self.buffer_size, total_rows, ind_var_dim, num_segments)
            rows_per_buffer = min(candidate, rows_per_buffer)

        # Round the buffer count up so every segment gets the same number
        buffers_per_segment = ceil(
            (1.0 * total_rows) / rows_per_buffer / num_segments)
        num_buffers = num_segments * buffers_per_segment
        return int(ceil(total_rows / num_buffers))


class ValidationDataPreprocessorDL(InputDataPreprocessorDL):
    """
        Preprocessor for a validation dataset. The normalizing constant,
        num_classes and class values are read from the summary table of
        the given training_preprocessor_table.
    """
    def __init__(self, schema_madlib, source_table, output_table,
                 dependent_varname, independent_varname,
                 training_preprocessor_table, buffer_size, distribution_rules,
                 **kwargs):
        """
            Read the training summary row and use it to set up the
            variables that are required by InputDataPreprocessorDL.
        """
        # Both attributes are read while validating the summary table, so
        # they have to be assigned before that call.
        self.module_name = "validation_preprocessor_dl"
        self.training_preprocessor_table = training_preprocessor_table
        training_summary = \
            self._validate_and_process_training_preprocessor_table()
        num_classes = training_summary[NUM_CLASSES_COLNAME]
        normalizing_const = training_summary[NORMALIZING_CONST_COLNAME]
        InputDataPreprocessorDL.__init__(
            self, schema_madlib, source_table, output_table,
            dependent_varname, independent_varname, buffer_size,
            normalizing_const, num_classes,
            distribution_rules, self.module_name)
        self.summary_dep_name = training_summary[DEPENDENT_VARNAME_COLNAME]
        self.summary_ind_name = training_summary[INDEPENDENT_VARNAME_COLNAME]
        # Replace dependent_levels with the values recorded in the training
        # batch summary table.
        self.dependent_levels = self._get_dependent_levels(training_summary)

    def _get_dependent_levels(self, summary_table):
        """
            Return, for every dependent variable listed in the training
            summary, the class levels to use when one-hot encoding the
            validation data.

            The levels come from the '<dep>_class_values' entries of the
            training summary row. Those entries are already padded with
            NULLs up to num_classes, so the padding is stripped here; it
            gets added again later in
            InputDataPreprocessorDL._set_one_hot_encoding_variables.
            Levels are quoted when the dependent variable is a text type.

            An error is raised through _assert() when the validation data
            contains a class value that the training data did not have.
        """
        # NOTE(review): the training dependent var type is looked up here
        # but is not compared against the validation type in this method.
        training_dependent_vartype = summary_table[DEPENDENT_VARTYPE_COLNAME]

        levels_per_dep_var = []
        for idx, summary_dep in enumerate(self.summary_dep_name):
            class_values_key = "{0}_class_values".format(summary_dep)
            padded_training_levels = summary_table[class_values_key]
            training_levels = strip_trailing_nulls_from_class_values(
                padded_training_levels)

            if padded_training_levels:
                validation_levels = self.get_distinct_dependent_levels(
                    self.source_table,
                    self.dependent_varname[idx],
                    self.dependent_vartype[idx])
                unquoted_validation_levels = [
                    strip_end_quotes(level, "'")
                    for level in validation_levels]
                # Every class value seen in the validation data must also
                # have been present in the training data.
                is_subset = (set(unquoted_validation_levels)
                             <= set(training_levels))
                _assert(is_subset,
                        "{0}: the class values in {1} ({2}) should be a "
                        "subset of class values in {3} ({4})".format(
                            self.module_name, self.source_table,
                            unquoted_validation_levels,
                            self.training_preprocessor_table,
                            training_levels))

            if is_psql_char_type(self.dependent_vartype[idx]):
                # Text levels are stored quoted; NULL entries stay None.
                training_levels = [
                    None if level is None else quote_literal(level)
                    for level in training_levels]
            levels_per_dep_var.append(training_levels)
        return levels_per_dep_var

    def _validate_and_process_training_preprocessor_table(self):
        """
            Check the training_preprocessor_table param and return the
            first row of its summary table.

            Both the table and its '_summary' companion must exist, and
            the summary row must contain the normalizing constant and
            num_classes columns.
        """
        input_tbl_valid(self.training_preprocessor_table, self.module_name)
        training_summary_table = add_postfix(
            self.training_preprocessor_table, "_summary")
        missing_summary_hint = ("Please ensure that table '{0}' "
                                "has been preprocessed using "
                                "training_preprocessor_dl()."
                                .format(self.training_preprocessor_table))
        input_tbl_valid(training_summary_table, self.module_name,
                        error_suffix_str=missing_summary_hint)
        summary_rows = plpy.execute("SELECT * FROM {0} LIMIT 1".format(
            training_summary_table))
        summary_row = summary_rows[0]
        # Only these two columns are asserted; the checks for the class
        # values and dependent vartype columns are currently disabled.
        for required_col in (NORMALIZING_CONST_COLNAME, NUM_CLASSES_COLNAME):
            _assert(required_col in summary_row,
                "{0}: Expected column {1} in {2}.".format(
                    self.module_name, required_col,
                    training_summary_table))
        return summary_row

    def validation_preprocessor_dl(self):
        """
            Run the preprocessing with order_by_random turned off.
        """
        self.input_preprocessor_dl(order_by_random=False)

class TrainingDataPreprocessorDL(InputDataPreprocessorDL):
    """
        Preprocessor for a training dataset; class levels are discovered
        from the source table itself.
    """
    def __init__(self, schema_madlib, source_table, output_table,
                 dependent_varname, independent_varname, buffer_size,
                 normalizing_const, num_classes, distribution_rules,
                **kwargs):
        """
            Set up the variables that are required by
            InputDataPreprocessorDL, then compute the dependent levels.
        """
        self.module_name = "training_preprocessor_dl"
        parent_args = (schema_madlib, source_table, output_table,
                       dependent_varname, independent_varname, buffer_size,
                       normalizing_const, num_classes, distribution_rules,
                       self.module_name)
        InputDataPreprocessorDL.__init__(self, *parent_args)
        # Replace the superclass default for dependent_levels.
        self.dependent_levels = self._get_dependent_levels()

    def _get_dependent_levels(self):
        """
            Return one entry per dependent variable: None when its type
            matches NUMERIC | ONLY_ARRAY, otherwise the distinct levels
            (NULLs included) found in the source table. Quoting of text
            levels is left to get_distinct_col_levels().
        """
        levels_per_dep_var = []
        for idx, dep_varname in enumerate(self.dependent_varname):
            dep_vartype = self.dependent_vartype[idx]
            if is_valid_psql_type(dep_vartype, NUMERIC | ONLY_ARRAY):
                # No level lookup is done for numeric array types.
                levels = None
            else:
                levels = get_distinct_col_levels(
                    self.source_table, dep_varname,
                    dep_vartype, include_nulls=True)
            levels_per_dep_var.append(levels)
        return levels_per_dep_var

    def training_preprocessor_dl(self):
        """
            Run the preprocessing with order_by_random turned on.
        """
        self.input_preprocessor_dl(order_by_random=True)

class InputDataPreprocessorDocumentation:
    """
        Builds the online help text for validation_preprocessor_dl() and
        training_preprocessor_dl().
    """
    @staticmethod
    def _select_help_text(schema_madlib, method, message, summary, usage):
        """
            Pick the help text that matches `message`.

            @param schema_madlib  Schema name substituted into the text.
            @param method         Name of the function being documented.
            @param message        Help option supplied by the user.
            @param summary        Text returned when message is empty/None.
            @param usage          Text returned for 'usage', 'help' or '?'
                                  (case-insensitive).
            @return The selected text, or a "No such option" notice for
                    any other message.
        """
        if not message:
            return summary
        if message.lower() in ('usage', 'help', '?'):
            return usage
        return """
            No such option. Use "SELECT {schema_madlib}.{method}()"
            for help.
        """.format(schema_madlib=schema_madlib, method=method)

    @staticmethod
    def validation_preprocessor_dl_help(schema_madlib, message):
        """
            Return the help text for validation_preprocessor_dl().

            @param schema_madlib  Schema name substituted into the text.
            @param message        None/empty for the summary; 'usage',
                                  'help' or '?' for the detailed usage.
            @return Help text as a string.
        """
        method = "validation_preprocessor_dl"
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        For Deep Learning based techniques such as Convolutional Neural Nets,
        the input data is mostly images. These images can be represented as an
        array of numbers where each element represents a pixel/color intensity.
        It is standard practice to normalize the image data before use.
        minibatch_preprocessor() is for general use-cases, but for deep learning
        based use-cases we provide training_preprocessor_dl() that is
        light-weight and is specific to image datasets.

        If you want to evaluate the model, a validation dataset has to
        be prepared. This validation data has to be in the same format
        as the corresponding batched training data used for training, i.e.,
        the two datasets must be normalized using the same normalizing
        constant, and the one-hot encoding of the dependent variable must
        follow the same convention. validation_preprocessor_dl() can be
        used to pre-process the validation data. To ensure that the format
        is similar to the corresponding training data, this function takes
        the output table name of training_preprocessor_dl() as an input
        param.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(schema_madlib=schema_madlib, method=method)

        # NOTE(review): the OUTPUT section below lists dependent_varname and
        # independent_varname twice (packed data and shape). The shape
        # columns presumably carry a distinct name -- confirm against the
        # code that creates the output table before rewording.
        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            source_table,          -- TEXT. Name of the table containing input
                                      data.  Can also be a view.
            output_table,          -- TEXT. Name of the output table for
                                      mini-batching.
            dependent_varname,     -- TEXT. Name of the dependent variable column.
            independent_varname,   -- TEXT. Name of the independent variable
                                      column.
            training_preprocessor_table, -- TEXT. packed training data table.
            buffer_size,           -- INTEGER. Default computed automatically.
                                      Number of source input rows to pack into a buffer.
            distribution_rules     -- TEXT. Default: 'all_segments'. Specifies how to
                                      distribute the 'output_table'. This is important
                                      for how the fit function will use resources on the
                                      cluster.  The default 'all_segments' means the
                                      'output_table' will be distributed to all segments
                                      in the database cluster.
        );


        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by validation_preprocessor_dl contains the
        following columns:

        buffer_id               -- INTEGER.  Unique id for packed table.
        dependent_varname       -- BYTEA. Packed array of dependent variables.
        independent_varname     -- BYTEA. Packed array of independent
                                   variables.
        dependent_varname       -- TEXT. Shape of the dependent variable buffer.
        independent_varname     -- TEXT. Shape of the independent variable buffer.

        ---------------------------------------------------------------------------
        The algorithm also creates a summary table named <output_table>_summary
        that has the following columns:

        source_table              -- Source table name.
        output_table              -- Output table name from preprocessor.
        dependent_varname         -- Dependent variable values from the original table
                                     (encoded by one_hot_encode, if specified).
        independent_varname       -- Independent variable values from the original
                                     table.
        dependent_vartype         -- Type of the dependent variable from the
                                     original table.
        class_values              -- Class values of the dependent variable
                                     ('NULL'(as TEXT type) for non
                                     categorical vars).
        buffer_size               -- Buffer size used in preprocessing step.
        normalizing_const         -- Normalizing constant used for standardizing
                                     arrays in independent_varname.
        num_classes               -- num_classes value passed by user while
                                     generating training_preprocessor_table.
        gpu_config                -- List of segment id's the data is distributed
                                     on depending on the 'distribution_rules' parameter
                                     specified as input. Set to 'all_segments' if
                                     'distribution_rules' is specified as 'all_segments'.

        ---------------------------------------------------------------------------
        """.format(schema_madlib=schema_madlib, method=method)

        return InputDataPreprocessorDocumentation._select_help_text(
            schema_madlib, method, message, summary, usage)

    @staticmethod
    def training_preprocessor_dl_help(schema_madlib, message):
        """
            Return the help text for training_preprocessor_dl().

            @param schema_madlib  Schema name substituted into the text.
            @param message        None/empty for the summary; 'usage',
                                  'help' or '?' for the detailed usage.
            @return Help text as a string.
        """
        method = "training_preprocessor_dl"
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        For Deep Learning based techniques such as Convolutional Neural Nets,
        the input data is mostly images. These images can be represented as an
        array of numbers where each element represents a pixel/color intensity.
        It is standard practice to normalize the image data before use.
        minibatch_preprocessor() is for general use-cases, but for deep learning
        based use-cases we provide training_preprocessor_dl() that is
        light-weight and is specific to image datasets.

        The normalizing constant is parameterized, and can be specified based
        on the kind of image data used.

        An optional param named num_classes can be used to specify the length
        of the one-hot encoded array for the dependent variable. This value if
        specified must be greater than or equal to the total number of distinct
        class values found in the input table.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(schema_madlib=schema_madlib, method=method)

        # NOTE(review): the OUTPUT section below lists dependent_varname and
        # independent_varname twice (packed data and shape). The shape
        # columns presumably carry a distinct name -- confirm against the
        # code that creates the output table before rewording.
        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            source_table,          -- TEXT. Name of the table containing input
                                      data.  Can also be a view.
            output_table,          -- TEXT. Name of the output table for
                                      mini-batching.
            dependent_varname,     -- TEXT. Name of the dependent variable column.
            independent_varname,   -- TEXT. Name of the independent variable
                                      column.
            buffer_size,           -- INTEGER. Default computed automatically.
                                      Number of source input rows to pack into a buffer.
            normalizing_const,     -- REAL. Default 1.0. The normalizing constant to
                                      use for standardizing arrays in independent_varname.
            num_classes,           -- INTEGER. Default NULL. Number of class labels
                                      to be considered for 1-hot encoding. If NULL,
                                      the 1-hot encoded array length will be equal to
                                      the number of distinct class values found in the
                                      input table.
            distribution_rules     -- TEXT. Default: 'all_segments'. Specifies how to
                                      distribute the 'output_table'. This is important
                                      for how the fit function will use resources on the
                                      cluster.  The default 'all_segments' means the
                                      'output_table' will be distributed to all segments
                                      in the database cluster.
        );


        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by training_preprocessor_dl contains the
        following columns:

        buffer_id               -- INTEGER.  Unique id for packed table.
        dependent_varname       -- BYTEA. Packed array of dependent variables.
        independent_varname     -- BYTEA. Packed array of independent
                                   variables.
        dependent_varname       -- TEXT. Shape of the dependent variable buffer.
        independent_varname     -- TEXT. Shape of the independent variable buffer.

        ---------------------------------------------------------------------------
        The algorithm also creates a summary table named <output_table>_summary
        that has the following columns:

        source_table              -- Source table name.
        output_table              -- Output table name from preprocessor.
        dependent_varname         -- Dependent variable values from the original table
                                     (encoded by one_hot_encode, if specified).
        independent_varname       -- Independent variable values from the original
                                     table.
        dependent_vartype         -- Type of the dependent variable from the
                                     original table.
        class_values              -- Class values of the dependent variable
                                     ('NULL'(as TEXT type) for non
                                     categorical vars).
        buffer_size               -- Buffer size used in preprocessing step.
        normalizing_const         -- Normalizing constant used for standardizing
                                     arrays in independent_varname.
        num_classes               -- num_classes input param passed to function.
        gpu_config                -- List of segment id's the data is distributed
                                     on depending on the 'distribution_rules' param
                                     specified as input. Set to 'all_segments' if
                                     'distribution_rules' is specified as 'all_segments'.


        ---------------------------------------------------------------------------
        """.format(schema_madlib=schema_madlib, method=method)

        return InputDataPreprocessorDocumentation._select_help_text(
            schema_madlib, method, message, summary, usage)
