# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file EXCEPT in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# m4_changequote(`<!', `!>')

import math
from collections import defaultdict

if __name__ != "__main__":
    import plpy
    from utilities.control import MinWarning
    from utilities.utilities import _assert
    from utilities.utilities import extract_keyvalue_params
    from utilities.utilities import unique_string
    from utilities.utilities import collate_plpy_result
    from utilities.utilities import get_grouping_col_str
    from utilities.validate_args import columns_exist_in_table
    from utilities.validate_args import explicit_bool_to_text
    from utilities.validate_args import get_cols
    from utilities.validate_args import table_exists
    from utilities.validate_args import table_is_empty
else:
    # Used only for Unit Testing
    # FIXME: this repeats a function from utilities that is needed by the unit test.
    # It should be removed once a unittest framework is used for testing.
    import random
    import time

    def unique_string(desp='', **kwargs):
        """
        Return a pseudo-random name for temporary tables and similar objects.

        Local copy of the utilities version, only defined when this file is
        run as a script. The name has the form
        "__madlib_temp_<desp><rand>_<epoch>_<mixed>__"; any extra keyword
        arguments are accepted and ignored.
        """
        rand_part = random.randint(1, 100000000)
        epoch_secs = int(time.time())
        mixed_part = int(time.time()) % random.randint(1, 100000000)
        pieces = [desp + str(rand_part), str(epoch_secs), str(mixed_part)]
        return "__madlib_temp_" + "_".join(pieces) + "__"
# ------------------------------------------------------------------------------

# Sampling strategies. The first three can be requested through class_sizes;
# NOSAMPLE marks class levels whose rows are copied to the output unchanged.
UNIFORM, UNDERSAMPLE, OVERSAMPLE, NOSAMPLE = (
    'uniform', 'undersample', 'oversample', 'nosample')

# Name of the id column prepended to every output row.
NEW_ID_COLUMN = '__madlib_id__'
# Text placeholder standing in for a NULL class level in generated SQL.
NULL_IDENTIFIER = '__madlib_null_id__'


def _get_level_frequency_distribution(source_table, class_col,
                                      grp_by_cols=None):
    """ Count the number of rows for each class level, partitioned by the
        grp_by_cols.

        @param source_table: Table (or view) to read the rows from
        @param class_col: Column (or expression) holding the class level
        @param grp_by_cols: Comma-separated grouping columns. None, an empty
                            string, or the string 'NULL' (any case) disables
                            grouping.

        @returns: dict keyed by group. With grouping the key is a tuple of the
                  group's column values; without grouping the single key is
                  None. Each value is a dict mapping class level -> row count.
                  Class levels are converted to strings using ::TEXT, and a
                  NULL class level appears as the key None.

        @raises RuntimeError: if the query returns an unexpected number of
                              rows (no rows with grouping, or anything other
                              than exactly one row without grouping).
    """
    if grp_by_cols and grp_by_cols.lower() != 'null':
        is_grouping = True
        grp_by_cols_comma = grp_by_cols + ', '
        array_grp_by_cols_comma = "ARRAY[{0}] AS group_values, ".format(grp_by_cols)
    else:
        is_grouping = False
        grp_by_cols_comma = array_grp_by_cols_comma = ""

    # In below query, the inner query groups the data using grp_by_cols + classes
    # and obtains the count for each combination. The outer query then groups
    # again by the grp_by_cols to collect the classes and counts in an array.
    query_result = plpy.execute("""
        SELECT
            -- For each group get the classes and their rows counts
            {grp_identifier} as group_values,
            array_agg(classes) as classes,
            array_agg(class_count) as class_count
        FROM(
            -- for each group and class combination present in source table
            -- get the count of rows for that combination
            SELECT
                {array_grp_by_cols_comma}
                ({class_col})::TEXT AS classes,
                count(*) AS class_count
            FROM {source_table}
            GROUP BY {grp_by_cols_comma} ({class_col})
        ) q
        {meta_grp_by}
     """.format(grp_identifier="group_values" if is_grouping else "NULL",
                meta_grp_by="GROUP BY group_values" if is_grouping else "",
                array_grp_by_cols_comma=array_grp_by_cols_comma,
                class_col=class_col,
                source_table=source_table,
                grp_by_cols_comma=grp_by_cols_comma))

    n_rows = len(query_result)
    # With grouping there is one row per group, and the data may legitimately
    # contain a single group, so only an empty result is an error. Without
    # grouping the outer aggregate has no GROUP BY and always yields one row.
    if (is_grouping and n_rows < 1) or (not is_grouping and n_rows != 1):
        raise RuntimeError("Balance sample: Error during frequency level distribution")

    actual_grp_level_counts = {}
    for each_row in query_result:
        # group_values is a list for each row; convert it to a tuple to use as
        # key in a dictionary
        grp = tuple(each_row['group_values']) if is_grouping else None
        grp_levels, grp_counts = each_row['classes'], each_row['class_count']
        actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts))
    return actual_grp_level_counts
# ------------------------------------------------------------------------------


def _validate_and_get_sampling_strategy(sampling_strategy_str,
                                        output_table_size,
                                        default=UNIFORM):
    """ Returns the sampling strategy based on the class_sizes input param.

        @param sampling_strategy_str: The sampling strategy specified by the
                                      user (class_sizes param). A
                                      case-insensitive prefix of at least 3
                                      characters of a supported strategy name
                                      is accepted.
        @param output_table_size: Requested output table size (may be None).
                                  Must not be set together with undersample
                                  or oversample.
        @param default: Strategy used when sampling_strategy_str is empty.
        @returns:
            Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
    """
    # Bound before branching: the error message in the _assert below
    # references it on every path, including the default-strategy path.
    supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE)

    if not sampling_strategy_str:
        sampling_strategy_str = default
    else:
        if len(sampling_strategy_str) < 3:
            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
            # common prefix substring
            plpy.error("Sample: Invalid class_sizes parameter")

        try:
            # allow user to specify a prefix substring of
            # supported strategies.
            sampling_strategy_str = next(x for x in supported_strategies
                                         if x.startswith(sampling_strategy_str.lower()))
        except StopIteration:
            # next() raises StopIteration if no element is found
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})".
                       format(sampling_strategy_str,
                              ','.join(sorted(supported_strategies))))

    _assert(sampling_strategy_str.lower() in supported_strategies or
            (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class_sizes parameter: "
            "{0}. Supported class_size parameters are ({1})".
            format(sampling_strategy_str,
                   ','.join(sorted(supported_strategies))))

    _assert(not (sampling_strategy_str.lower() == OVERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(not (sampling_strategy_str.lower() == UNDERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
# ------------------------------------------------------------------------------


def _choose_strategy(actual_count, desired_count):
    """ Pick the sampling strategy for a single class level.

    @param actual_count: Number of rows the level currently has
    @param desired_count: Number of rows wanted for the level
    @returns:
        Str. OVERSAMPLE when actual_count < desired_count, otherwise
        UNDERSAMPLE.
    """
    # A level that already has exactly the desired size is still treated as
    # UNDERSAMPLE rather than NOSAMPLE. Undersampling can be done with or
    # without replacement, so the level can still be bootstrapped (same row
    # count, drawn with replacement). OVERSAMPLE always samples with
    # replacement, and NOSAMPLE never samples at all.
    needs_more_rows = actual_count < desired_count
    return OVERSAMPLE if needs_more_rows else UNDERSAMPLE
# -------------------------------------------------------------------------


def _get_desired_target_level_counts(desired_level_counts,
                                     actual_grp_level_counts,
                                     output_table_size):
    """ Return the target counts for each group and each class level in the group

        This function is specifically used when the user has provided the
        targets for all (or a subset) of the levels. The strategy of either
        under or oversampling for each class level is chosen based on the
        desired number of counts for a level, and the actual number of counts
        already present in the input table. This calculation is performed for
        each group (or once if no grouping is used).

        @param desired_level_counts: dict of class level -> desired row count
                                     for the whole table
        @param actual_grp_level_counts: dict of group -> {class level -> actual
                                        row count in that group}
        @param output_table_size: Desired total output size, or None/0

        @returns: dict: The key of the dictionary is the group value and the
                        value is another dictionary. The inner dictionary maps
                        each level to a (target_count, sampling_strategy) tuple.
                        Levels not named by the user get NOSAMPLE when
                        output_table_size is unset; when it is set they share
                        the remaining rows, or get no entry if none remain.

        @raises: via plpy.error, if a desired level is missing from a group.
    """
    target_grp_level_counts = defaultdict(dict)
    n_grp_values = len(actual_grp_level_counts)
    for each_grp, actual_level_counts in actual_grp_level_counts.items():
        for each_level, desired_count in desired_level_counts.items():
            # actual_grp_level_counts are group specific, whereas
            # desired_level_counts are evenly split among all groups
            per_grp_desired_count = math.ceil(float(desired_count) / n_grp_values)
            if each_level not in actual_level_counts:
                plpy.error("Balance sample: Desired class level ({0}) not present in the "
                           "data for each group.".format(each_level))
            sample_strategy = _choose_strategy(actual_level_counts[each_level],
                                               per_grp_desired_count)
            target_grp_level_counts[each_grp][each_level] = (
                per_grp_desired_count, sample_strategy)

        # desired levels could contain just a subset of all levels. For the remaining
        # levels in actual_level_counts, compute the desired counts
        remaining_levels = (set(actual_level_counts.keys()) -
                            set(desired_level_counts.keys()))

        #   if 'output_table_size' = NULL, remaining level counts remain as is
        #   if 'output_table_size' = <Integer>, divide remaining count
        #                             uniformly among remaining levels
        if output_table_size:
            # output_table_size is for the whole table and should be split evenly
            # between the groups
            remaining_rows = math.ceil(float(output_table_size -
                                             sum(desired_level_counts.values())) /
                                       n_grp_values)
            # Nothing to distribute when the user already gave a count for
            # every level (dividing by len(remaining_levels) would raise
            # ZeroDivisionError), or when the given counts already use up
            # output_table_size.
            if remaining_rows > 0 and remaining_levels:
                # Uniformly distribute the remaining class levels
                rows_per_level = math.ceil(float(remaining_rows) /
                                           len(remaining_levels))
                for each_level in remaining_levels:
                    sample_strategy = _choose_strategy(
                        actual_level_counts[each_level], rows_per_level)
                    target_grp_level_counts[each_grp][each_level] = (
                        rows_per_level, sample_strategy)
        else:
            # When output_table_size is unspecified, rows from the input table
            # are sampled as is for remaining class levels. This is called as the
            # NOSAMPLE strategy.
            for each_level in remaining_levels:
                target_grp_level_counts[each_grp][each_level] = (
                    actual_level_counts[each_level], NOSAMPLE)
    return target_grp_level_counts
# -------------------------------------------------------------------------


def _get_supported_target_level_counts(sampling_strategy_str,
                                       actual_grp_level_counts,
                                       output_table_size):
    """ Compute per-group, per-level targets for a named class_sizes strategy.

        @param sampling_strategy_str: One of UNIFORM, UNDERSAMPLE, OVERSAMPLE
        @param actual_grp_level_counts: dict of group -> {class level -> actual
                                        row count in that group}
        @param output_table_size: Desired total output size (only consulted by
                                  UNIFORM), or None/0

        @returns: dict of group -> {class level -> (target_count, strategy)},
                  where the strategy (under/oversampling) is chosen from the
                  target and the level's actual count.

        @raises RuntimeError: for any other sampling_strategy_str.
    """
    n_groups = len(actual_grp_level_counts)
    result = defaultdict(dict)
    for grp_key, level_counts in actual_grp_level_counts.items():
        counts = level_counts.values()
        if sampling_strategy_str == UNIFORM:
            # Every level gets an equal (rounded-up) share of the group total:
            # the group's share of output_table_size when one is given,
            # otherwise the group's current number of rows.
            if output_table_size:
                grp_total = float(output_table_size) / n_groups
            else:
                grp_total = sum(counts)
            per_level = math.ceil(float(grp_total) / len(level_counts))
        elif sampling_strategy_str == UNDERSAMPLE:
            # Shrink every level to the smallest level's size
            per_level = min(counts)
        elif sampling_strategy_str == OVERSAMPLE:
            # Grow every level to the largest level's size
            per_level = max(counts)
        else:
            raise RuntimeError("Balance sample: Invalid "
                               "sampling_strategy_str encountered")

        for level, n_actual in level_counts.items():
            result[grp_key][level] = (per_level,
                                      _choose_strategy(n_actual, per_level))
    return result
# -------------------------------------------------------------------------


def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
                             actual_grp_level_counts, output_table_size):
    """ Dispatch to the right target-count computation for class_sizes.

    @param sampling_strategy_str: One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE],
                                  or a falsy value when class_sizes was a
                                  user-defined list of level=count pairs.
    @param desired_level_counts: Only used when sampling_strategy_str is
                                 falsy: dict of class level -> number of rows
                                 requested by the user.
    @param actual_grp_level_counts: dict of group -> {class level -> number
                                    of rows of that level in the group}
    @param output_table_size: Size of the desired output table (NULL or Integer)

    @returns:
        Dict of group -> {class level -> (number of samples to draw,
        sampling strategy)}.
    """
    if sampling_strategy_str:
        # The user named one of [uniform, undersample, oversample]
        return _get_supported_target_level_counts(
            sampling_strategy_str, actual_grp_level_counts, output_table_size)

    # The user gave explicit counts for one or more levels; the rest of the
    # levels are handled according to output_table_size.
    return _get_desired_target_level_counts(
        desired_level_counts, actual_grp_level_counts, output_table_size)
# -------------------------------------------------------------------------


def _get_sampling_strategy_counts(target_class_sizes):
    """ Split per-level targets by sampling strategy.

        @param target_class_sizes: dict of class level -> (count, strategy)

        @returns: tuple (undersample, oversample, nosample) of dicts, each
                  mapping class level -> count. Any strategy other than
                  UNDERSAMPLE or OVERSAMPLE ends up in the nosample dict.
    """
    undersample, oversample, nosample = {}, {}, {}
    for level, (count, strategy) in target_class_sizes.items():
        if strategy == UNDERSAMPLE:
            bucket = undersample
        elif strategy == OVERSAMPLE:
            bucket = oversample
        else:
            bucket = nosample
        bucket[level] = count
    return (undersample, oversample, nosample)
# ------------------------------------------------------------------------------


def _get_nosample_subquery(source_table, class_col, nosample_levels,
                           grp_dict=None):
    """ Return the subquery for fetching all rows as is from the input table
        for specific class levels.
    """
    if not nosample_levels:
        return ''
    nosample_level_str = ','.join(["'{0}'".format(level)
                                   for level in nosample_levels if level])
    if grp_dict:
        grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
                                            for k, v in grp_dict.items())
    else:
        grp_filter = ''
    return """
        SELECT *
        FROM {source_table}
        WHERE ({class_col} in ({nosample_level_str}) OR
               {class_col} IS NULL)
              {grp_filter}
        """.format(**locals())
# ------------------------------------------------------------------------------


def _get_without_replacement_subquery(schema_madlib, source_table, class_col,
                                      actual_level_counts, desired_level_counts,
                                      grp_dict=None):
    """ Return the subquery for sampling without replacement for specific
        class levels.

        @param schema_madlib: Unused here; the parameter list mirrors
                              _get_with_replacement_subquery, and
                              balance_sample calls the two interchangeably
        @param source_table: Table (or view) to sample from
        @param class_col: Column (or expression) holding the class level
        @param actual_level_counts: Unused here (see schema_madlib)
        @param desired_level_counts: dict of class level (text, or None for
                                     the NULL level) -> number of rows to keep
        @param grp_dict: Optional dict mapping grouping column -> value that
                         restricts the sample to one group

        @returns: SQL string selecting the source table's columns, or '' when
                  desired_level_counts is empty.
    """
    if not desired_level_counts:
        return ''
    class_col_tmp = unique_string(desp='class_col')
    row_number_col = unique_string(desp='row_number')
    desired_count_col = unique_string(desp='desired_count')
    source_table_columns = ','.join(get_cols(source_table))

    # Build "VALUES (level, desired_count), ...". The NULL level is encoded as
    # NULL_IDENTIFIER so it can be matched with an equality join. Every other
    # level, including the empty string, becomes a text literal with single
    # quotes doubled.
    level_count_pairs = []
    for level, count in desired_level_counts.items():
        if level is None:
            level_literal = "'{0}'".format(NULL_IDENTIFIER)
        else:
            level_literal = "'{0}'::text".format(
                "{0}".format(level).replace("'", "''"))
        level_count_pairs.append("({0}, {1})".format(level_literal, count))
    desired_level_counts_str = "VALUES " + ','.join(level_count_pairs)

    if grp_dict:
        # A NULL group value must be matched with IS NULL, not "= 'None'"
        grp_filter = ' AND ' + ' AND '.join(
            "{0} IS NULL".format(k) if v is None else
            "{0} = '{1}'".format(k, "{0}".format(v).replace("'", "''"))
            for k, v in grp_dict.items())
    else:
        grp_filter = ''

    # q1: (class_level, desired_number_of_rows) pairs from the VALUES list.
    # q2: source rows of the requested levels (and group), each joined to the
    #     desired row count of its level.
    # q3: rows of each level numbered in random order.
    # Outer WHERE: keep the first desired_count rows of each level, i.e. a
    #     random sample without replacement.
    subquery = """
            SELECT {source_table_columns}
            FROM (
                    SELECT {source_table_columns},
                           row_number() OVER (PARTITION BY {class_col} ORDER BY random()) AS {row_number_col},
                           {desired_count_col}
                    FROM (
                        SELECT {source_table_columns},
                               {desired_count_col}
                        FROM
                            {source_table} s,
                            ({desired_level_counts_str})
                                q1({class_col_tmp}, {desired_count_col})
                        WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_identifier}')
                              {grp_filter}
                    ) q2
                ) q3
            WHERE {row_number_col} <= {desired_count_col}
        """.format(source_table_columns=source_table_columns,
                   class_col=class_col,
                   row_number_col=row_number_col,
                   desired_count_col=desired_count_col,
                   source_table=source_table,
                   desired_level_counts_str=desired_level_counts_str,
                   class_col_tmp=class_col_tmp,
                   null_identifier=NULL_IDENTIFIER,
                   grp_filter=grp_filter)
    return subquery
# ------------------------------------------------------------------------------


def _get_with_replacement_subquery(schema_madlib, source_table, class_col,
                                   actual_level_counts, desired_level_counts,
                                   grp_dict=None):
    """ Return the query for sampling with replacement for specific class
        levels. Always used for oversampling since oversampling will always need
        to use replacement. Used for under sampling only if with_replacement
        flag is set to TRUE.

        @param schema_madlib: Unused here
        @param source_table: Table (or view) to sample from
        @param class_col: Column (or expression) holding the class level
        @param actual_level_counts: dict of class level -> number of rows of
                                    that level in source_table (within the
                                    group, when grp_dict is given)
        @param desired_level_counts: dict of class level (text, or None for
                                     the NULL level) -> number of rows to draw
        @param grp_dict: Optional dict mapping grouping column -> value that
                         restricts the sample to one group

        @returns: SQL string selecting the source table's columns, or '' when
                  desired_level_counts is empty.
    """
    if not desired_level_counts:
        return ''

    source_table_columns = ','.join(get_cols(source_table))
    class_col_tmp = unique_string(desp='class_col_with_rep')
    desired_count_col = unique_string(desp='desired_count_with_rep')
    actual_count_col = unique_string(desp='actual_count')
    q1_row_no = unique_string(desp='q1_row')
    q2_row_no = unique_string(desp='q2_row')

    # Build "VALUES (level, desired_count, actual_count), ...". The NULL level
    # is encoded as NULL_IDENTIFIER so it can be matched with an equality
    # join. Every other level, including the empty string, becomes a text
    # literal with single quotes doubled.
    level_count_triples = []
    for level, count in desired_level_counts.items():
        if level is None:
            level_literal = "'{0}'".format(NULL_IDENTIFIER)
        else:
            level_literal = "'{0}'::text".format(
                "{0}".format(level).replace("'", "''"))
        level_count_triples.append("({0}, {1}, {2})".format(
            level_literal, count, actual_level_counts[level]))
    desired_and_actual_level_count_str = "VALUES " + ','.join(level_count_triples)

    # q1:
    #    One row per output row: generate_series repeats each
    #    (class_level, desired_rows, actual_rows) tuple desired_rows times and
    #    attaches a random row number in 1..actual_rows.
    #    floor(random() * n) + 1 draws each of 1..n with equal probability
    #    (random() is in [0, 1)). Rounding random()*(n-1)+1 instead would make
    #    1 and n half as likely as the other rows.
    # q2:
    #   The source rows (of the group) numbered 1..actual_rows within each
    #   class level.
    #
    # The WHERE clause joins q1 to q2 on level and row number, so every q1 row
    # picks one source row; rows can repeat (sampling with replacement).
    if grp_dict:
        # A NULL group value must be matched with IS NULL, not "= 'None'"
        grp_filter = 'WHERE ' + ' AND '.join(
            "{0} IS NULL".format(k) if v is None else
            "{0} = '{1}'".format(k, "{0}".format(v).replace("'", "''"))
            for k, v in grp_dict.items())
    else:
        grp_filter = ''
    subquery = """
            SELECT {source_table_columns}
            FROM
                (
                    SELECT
                         {class_col_tmp},
                         generate_series(1, {desired_count_col}::int) AS _i,
                         (floor(random() * {actual_count_col})::int + 1) AS {q1_row_no}
                    FROM
                        ({desired_and_actual_level_count_str})
                            q({class_col_tmp}, {desired_count_col}, {actual_count_col})
                ) q1,
                (
                    SELECT
                        *,
                        row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
                    FROM
                         {source_table}
                    {grp_filter}
                ) q2
            WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND
                  q1.{q1_row_no} = q2.{q2_row_no}
        """.format(source_table_columns=source_table_columns,
                   class_col_tmp=class_col_tmp,
                   desired_count_col=desired_count_col,
                   actual_count_col=actual_count_col,
                   q1_row_no=q1_row_no,
                   q2_row_no=q2_row_no,
                   desired_and_actual_level_count_str=desired_and_actual_level_count_str,
                   class_col=class_col,
                   source_table=source_table,
                   grp_filter=grp_filter,
                   null_level_val=NULL_IDENTIFIER)
    return subquery
# ------------------------------------------------------------------------------


def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, keep_null, **kwargs):
    """
    Balance sampling function
    Args:
        @param schema_madlib      Schema that MADlib is installed on.
        @param source_table       Input table name.
        @param output_table       Output table name.
        @param class_col          Name of the column containing the class to be
                                  balanced.
        @param class_sizes        Parameter to define the size of the different
                                  class values: either a strategy name (a
                                  prefix of uniform/undersample/oversample)
                                  or key=value pairs mapping class levels to
                                  row counts. Empty, NULL, or whitespace-only
                                  means uniform.
        @param output_table_size  Desired size of the output data set.
        @param grouping_cols      The columns that define the grouping.
        @param with_replacement   Whether undersampled levels are drawn with
                                  replacement (oversampled levels always are).
        @param keep_null          Flag to include rows with class level values
                                  NULL. Default is False.

    The output table holds the sampled rows preceded by an id column named
    NEW_ID_COLUMN, numbered consecutively across all groups.
    """
    with MinWarning("warning"):

        # Set all default values. Strip before defaulting so that a
        # whitespace-only class_sizes falls back to UNIFORM instead of
        # reaching the strategy parser as an empty string.
        class_sizes = (class_sizes or '').strip() or UNIFORM
        if not with_replacement:
            with_replacement = False
        keep_null = bool(keep_null)

        _validate_strs(source_table, output_table, class_col,
                       output_table_size, grouping_cols)

        # If keep_null=False, create a view of the input table ignoring NULL
        # values for class levels.
        if keep_null:
            new_source_table = source_table
        else:
            new_source_table = unique_string(desp='source_table')
            plpy.execute("""
                CREATE VIEW {new_source_table} AS
                SELECT * FROM {source_table}
                WHERE {class_col} IS NOT NULL
            """.format(new_source_table=new_source_table,
                       source_table=source_table,
                       class_col=class_col))
        # class_sizes can be of two forms:
        #   1. A string describing sampling strategy (as described in
        #       _validate_and_get_sampling_strategy).
        #       In this case, 'sampling_strategy_str' is set to one of
        #       [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
        #   2. Class sizes for all (or a subset) of the class levels
        #       In this case, sampling_strategy_str = None and parsed_class_sizes
        #       is used for the sampling.
        parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                     allow_duplicates=False,
                                                     lower_case_names=False)

        # GREENPLUM 4.3.X does not support bool-to-text cast which is relied
        #  upon for class_col in multiple queries. The explicit_bool_to_text
        #  function wraps the class_col with a MADlib function that provides the
        #  cast just for those platforms that don't provide the cast. For
        #  platforms that provide the cast, class_col is unchanged below.
        class_col = explicit_bool_to_text(source_table,
                                          [class_col],
                                          schema_madlib)[0]
        distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
                        format(class_col,
                               source_table))
        distinct_levels = collate_plpy_result(plpy.execute(distinct_sql))['levels']
        if not parsed_class_sizes:
            sampling_strategy_str = _validate_and_get_sampling_strategy(
                class_sizes, output_table_size)
        else:
            sampling_strategy_str = None
            try:
                for each_level, each_class_size in parsed_class_sizes.items():
                    _assert(each_level in distinct_levels,
                            "Sample: Invalid class value specified ({0})".
                            format(each_level))
                    each_class_size = int(each_class_size)
                    _assert(each_class_size >= 1,
                            "Sample: Class size has to be greater than zero")
                    parsed_class_sizes[each_level] = each_class_size
            except ValueError:
                plpy.error("Sample: Invalid value for class_sizes ({0})".
                           format(class_sizes))

        # Get the number of rows to be sampled for each class level, based on
        # the input table, class_sizes, and output_table_size params. This also
        # includes info about the resulting sampling strategy, i.e., one of
        # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
        grp_col_str, grp_cols = get_grouping_col_str(
            schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
            source_table, grouping_cols)
        actual_grp_level_counts = _get_level_frequency_distribution(
            new_source_table, class_col, grp_col_str)
        target_grp_class_sizes = _get_target_level_counts(
            sampling_strategy_str, parsed_class_sizes,
            actual_grp_level_counts, output_table_size)

        is_output_created = False
        grp_cols_list = grp_col_str.split(',') if grp_cols else []

        # for each group
        for grp_vals, actual_level_counts in actual_grp_level_counts.items():
            target_class_sizes = target_grp_class_sizes[grp_vals]

            (undersample_level_counts,
             oversample_level_counts,
             nosample_level_counts) = _get_sampling_strategy_counts(target_class_sizes)

            # grp_dict represents each grouping column and its value
            # for current iteration
            grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None

            # Get subqueries for each sampling strategy and union together in
            # one big query.

            # NOSAMPLE are levels that are to be retained in output without sampling
            nosample_subquery = _get_nosample_subquery(
                new_source_table, class_col, nosample_level_counts.keys(), grp_dict)
            # OVERSAMPLE are levels that to be sampled more than existing count
            # (always with replacement)
            oversample_subquery = _get_with_replacement_subquery(
                schema_madlib, new_source_table, class_col,
                actual_level_counts, oversample_level_counts, grp_dict)
            # UNDERSAMPLE are levels that are to be sampled to less than
            # existing count. Undersampling supports both with and without
            # replacement.
            if with_replacement:
                undersample_subquery = _get_with_replacement_subquery(
                    schema_madlib, new_source_table, class_col,
                    actual_level_counts, undersample_level_counts, grp_dict)
            else:
                undersample_subquery = _get_without_replacement_subquery(
                    schema_madlib, new_source_table, class_col,
                    actual_level_counts, undersample_level_counts, grp_dict)

            # Merge the three subqueries using a UNION ALL clause.
            sampling_queries = (undersample_subquery, oversample_subquery, nosample_subquery)
            union_all_subquery = ' UNION ALL '.join(
                ['({0})'.format(subquery)
                 for subquery in sampling_queries if subquery])

            # Populate the output table. Each group is written by its own
            # statement, so later groups offset row_number() by the largest
            # id already written. Without the offset, ids would restart at 1
            # for every group.
            if is_output_created:
                table_header = "INSERT INTO {0}".format(output_table)
                id_offset = "(SELECT coalesce(max({0}), 0) FROM {1})".format(
                    NEW_ID_COLUMN, output_table)
            else:
                table_header = "CREATE TABLE {0} AS ".format(output_table)
                id_offset = "0"
                is_output_created = True
            final_query = """
                {table_header}
                SELECT row_number() OVER() + {id_offset} AS {id_col_name}, *
                FROM (
                    {union_all_subquery}
                ) union_query
            """.format(table_header=table_header,
                       id_offset=id_offset,
                       id_col_name=NEW_ID_COLUMN,
                       union_all_subquery=union_all_subquery)
            plpy.execute(final_query)
        # end of grouping loop

        if not keep_null:
            plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table))
# ------------------------------------------------------------------------------


def _validate_strs(source_table, output_table, class_col,
                   output_table_size, grouping_cols):
    """
    Sanity-check the table/column/size arguments given to balance_sample.

    Each check hands its condition and error message to _assert. This
    assumes _assert raises on a false condition, so the first failing check
    aborts validation.

    Args:
        @param source_table: Name of the input table. It must be given, must
                             exist and must not be empty.
        @param output_table: Name of the table to create. It must be given
                             and must not already exist.
        @param class_col: Column holding the class labels. It must be given
                          and must exist in source_table.
        @param output_table_size: Requested number of output rows. A falsy
                                  value (None/0) skips the check; any other
                                  value must be greater than zero.
        @param grouping_cols: Accepted for signature compatibility; not
                              inspected here.

    Returns:
        None
    """
    # Input table: must be named, present and populated.
    src_exists = source_table and table_exists(source_table)
    _assert(src_exists,
            "Sample: Source table ({source_table}) does not exist.".format(
                source_table=source_table))
    src_has_rows = not table_is_empty(source_table)
    _assert(src_has_rows,
            "Sample: Source table ({source_table}) is empty.".format(
                source_table=source_table))

    # Output table: must be named and must not clash with an existing table.
    _assert(output_table, "Sample: Output table name is missing.")
    out_is_new = not table_exists(output_table)
    _assert(out_is_new,
            "Sample: Output table ({output_table}) already exists.".format(
                output_table=output_table))

    # Class column: must be named and present in the input table.
    _assert(class_col, "Sample: Class column name is missing.")
    class_col_found = columns_exist_in_table(source_table, [class_col])
    _assert(class_col_found,
            ("Sample: Class column ({class_col}) does not exist in"
             " table ({source_table}).").format(class_col=class_col,
                                                 source_table=source_table))

    # The output gets a generated id column named NEW_ID_COLUMN, so the
    # input table must not already use that name.
    id_col_clash = columns_exist_in_table(source_table, [NEW_ID_COLUMN])
    _assert(not id_col_clash,
            ("Sample: Please ensure the source table ({0})"
             " does not contain a column named {1}").format(source_table,
                                                           NEW_ID_COLUMN))

    # Output size is optional; when supplied it has to be positive.
    _assert(not (output_table_size and output_table_size <= 0),
            "Sample: Invalid output table size ({output_table_size}).".format(
                output_table_size=output_table_size))
# ------------------------------------------------------------------------------


def balance_sample_help(schema_madlib, message, **kwargs):
    """
    Help function for balance_sample

    Args:
        @param schema_madlib: string, name of the MADlib schema; it is
                              substituted into the returned text.
        @param message: string, Help message string. An empty/None value
                        returns a short summary; 'usage', 'help' or '?'
                        (case-insensitive) return the detailed usage text;
                        anything else returns a "No such option" hint.
        @param kwargs: ignored.

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------
Given a table with a varying number of records for each class label,
this function creates an output table whose class labels follow a chosen
sampling distribution (by default: uniform). It is possible to sample
with or without replacement, specify different sizes for each class,
use multiple grouping columns and/or set the output table size.

For more details on function usage:
    SELECT {schema_madlib}.balance_sample('usage');
            """
    elif message.lower() in ['usage', 'help', '?']:
        # NOTE: the only brace pair allowed in this text is the
        # {schema_madlib} placeholder, because the string goes through
        # str.format below.
        help_string = """

Given a table with a varying number of records for each class label,
balanced sampling returns an output table in which each class label is
sampled to the requested size. It is possible to sample with or without
replacement, specify the size of each class, use grouping columns, and
set the size of the output table.

----------------------------------------------------------------------------
                            USAGE
----------------------------------------------------------------------------

 SELECT {schema_madlib}.balance_sample(
    source_table      TEXT,     -- Input table name.
    output_table      TEXT,     -- Output table name.
    class_col         TEXT,     -- Name of column containing the class to be
                                -- balanced.
    class_size        TEXT,     -- (Default: NULL) Parameter to define the size
                                -- of the different class values.
    output_table_size INTEGER,  -- (Default: NULL) Desired size of the output
                                -- data set.
    grouping_cols     TEXT,     -- (Default: NULL) The columns that define
                                -- the grouping.
    with_replacement  BOOLEAN,  -- (Default: FALSE) The sampling method.
    keep_null         BOOLEAN   -- (Default: FALSE) Consider class levels with
                                -- NULL values or not.
 );

If class_size is NULL, the source table is uniformly sampled.

If output_table_size is NULL, the resulting output table size will depend on
the settings for the 'class_size' parameter. It is ignored if the
'class_size' parameter is set to either 'oversample' or 'undersample'.

If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.

If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row of an undersampled class level can be selected at most
once. Class levels that must grow beyond their current count (oversampled)
are always sampled with replacement, regardless of this setting.

The output_table would contain the required number of samples, along with a
new column named __madlib_id__, that contains unique numbers for all
sampled rows.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.balance_sample()"

    return help_string.format(schema_madlib=schema_madlib)


import unittest


class UtilitiesTestCase(unittest.TestCase):
    """
        Unit tests for the pure-Python planning helpers behind balance_sample
        (_choose_strategy, _get_target_level_counts and
        _get_sampling_strategy_counts).

        Run them by executing this file directly. In that mode the plpy and
        utilities imports at the top of the file are skipped. Any exercised
        helper that still calls plpy.error must have that call replaced by
        an ordinary Python exception for the tests to pass.
    """

    def setUp(self):
        # Class level counts of a toy source table, keyed by group. The single
        # key None presumably means "no grouping" -- confirm against callers.
        self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}}
        self.level1a = 'a'
        self.level1a_cnt1, self.level1a_cnt2, self.level1a_cnt3 = 15, 25, 20

        # Sampling strategy strings: none given, then the named strategies.
        (self.sampling_strategy_str0, self.sampling_strategy_str1,
         self.sampling_strategy_str2, self.sampling_strategy_str3) = (
            '', 'uniform', 'oversample', 'undersample')

        # User-specified per-class target sizes ('' means none given).
        self.user_specified_class_size0 = ''
        self.user_specified_class_size1 = dict(a=25, b=25)
        self.user_specified_class_size2 = dict(b=25)
        self.user_specified_class_size3 = dict(a=30)

        # Output table sizes: unspecified, and 60 rows.
        self.output_table_size1, self.output_table_size2 = None, 60

    def test__choose_strategy(self):
        # (actual count, target count) -> expected strategy
        cases = (((35, 25), UNDERSAMPLE),
                 ((15, 25), OVERSAMPLE),
                 ((25, 25), UNDERSAMPLE))
        for args, expected in cases:
            self.assertEqual(expected, _choose_strategy(*args))

    def test__get_target_level_counts(self):
        counts = self.input_class_level_counts1
        no_strategy = self.sampling_strategy_str0
        no_class_size = self.user_specified_class_size0
        unsized, sized = self.output_table_size1, self.output_table_size2
        # (strategy, class sizes, output size) -> expected per-level targets
        cases = [
            # User-defined class sizes, output table size not given.
            (no_strategy, self.user_specified_class_size1, unsized,
             {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}),
            (no_strategy, self.user_specified_class_size2, unsized,
             {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}),
            (no_strategy, self.user_specified_class_size3, unsized,
             {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}),
            # User-defined class sizes together with an output table size.
            (no_strategy, self.user_specified_class_size1, sized,
             {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}),
            (no_strategy, self.user_specified_class_size2, sized,
             {'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}),
            (no_strategy, self.user_specified_class_size3, sized,
             {'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}),
            # Named strategies, output table size not given.
            (self.sampling_strategy_str1, no_class_size, unsized,
             {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}),
            (self.sampling_strategy_str2, no_class_size, unsized,
             {'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}),
            (self.sampling_strategy_str3, no_class_size, unsized,
             {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}),
            # Uniform strategy with an output table size.
            (self.sampling_strategy_str1, no_class_size, sized,
             {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}),
        ]
        for strategy, class_size, out_size, expected in cases:
            self.assertEqual({None: expected},
                             _get_target_level_counts(strategy, class_size,
                                                      counts, out_size))

    def test__get_sampling_strategy_specific_dict(self):
        # target level counts -> (undersample, oversample, nosample) dicts
        cases = (
            ({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)},
             ({'b': 25}, {'a': 25}, {'c': 25})),
            ({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)},
             ({'b': 25}, {'a': 25}, {})),
            ({'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 'c': (25, NOSAMPLE)},
             ({}, {'a': 25}, {'c': 25, 'b': 25})),
        )
        for target_counts, expected in cases:
            self.assertEqual(expected,
                             _get_sampling_strategy_counts(target_counts))


if __name__ == '__main__':
    # Executing this file directly runs the unit tests defined above; in this
    # mode the plpy/utilities imports at the top of the file are skipped.
    unittest.main()
