| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file EXCEPT in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # m4_changequote(`<!', `!>') |
| |
| import math |
| from collections import defaultdict |
| |
| if __name__ != "__main__": |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import _assert |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import unique_string |
| from utilities.utilities import collate_plpy_result |
| from utilities.utilities import get_grouping_col_str |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import explicit_bool_to_text |
| from utilities.validate_args import get_cols |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import table_is_empty |
| else: |
| # Used only for Unit Testing |
    # FIXME: repeating a function from utilities that is needed by the unit test.
    # This should be removed once a unittest framework is used for testing.
| import random |
| import time |
| |
    def unique_string(desp='', **kwargs):
        """
        Generate random temporary names for temp tables and other objects.

        Script-mode stand-in: when this file is imported normally,
        unique_string comes from utilities.utilities instead (see the
        imports above). Extra keyword arguments are accepted but unused.

        @param desp  Optional fragment embedded in the generated name.
        @returns     Str of the form "__madlib_temp_<desp><r1>_<r2>_<r3>__".
        """
        r1 = random.randint(1, 100000000)
        r2 = int(time.time())
        # Current epoch seconds modulo a random positive integer
        r3 = int(time.time()) % random.randint(1, 100000000)
        u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__"
        return u_string
| # ------------------------------------------------------------------------------ |
| |
# Sampling strategies. UNIFORM, UNDERSAMPLE and OVERSAMPLE are the values
# accepted for the class_sizes parameter; NOSAMPLE is used internally for
# class levels whose rows are copied to the output as is.
UNIFORM = 'uniform'
UNDERSAMPLE = 'undersample'
OVERSAMPLE = 'oversample'
NOSAMPLE = 'nosample'

# Name of the row-id column added to the output table.
NEW_ID_COLUMN = '__madlib_id__'
# Text sentinel standing in for a NULL class level when class levels are
# matched as text in the sampling subqueries.
NULL_IDENTIFIER = '__madlib_null_id__'
| |
| |
def _get_level_frequency_distribution(source_table, class_col,
                                      grp_by_cols=None):
    """ Count the number of rows for each class, partitioned by the grp_by_cols

    Args:
        @param source_table: Table (or view) whose rows are counted.
        @param class_col: Column (or expression) holding the class level.
        @param grp_by_cols: Comma-separated grouping expression. None, '' or
                            the string 'null' (any case) disables grouping.

    Returns a dict keyed by group: with grouping the key is the tuple of
    grouping values of the group, without grouping the single key is None.
    Each value is a dict mapping a class level (converted to a string using
    ::text; None captures a NULL class value) to its number of rows.

    Raises RuntimeError when the query returns a number of rows that is
    inconsistent with the grouping mode.
    """
    if grp_by_cols and grp_by_cols.lower() != 'null':
        is_grouping = True
        grp_by_cols_comma = grp_by_cols + ', '
        array_grp_by_cols_comma = "ARRAY[{0}]".format(grp_by_cols) + " AS group_values, "
    else:
        is_grouping = False
        grp_by_cols_comma = array_grp_by_cols_comma = ""

    # In below query, the inner query groups the data using grp_by_cols + classes
    # and obtains the count for each combination. The outer query then groups
    # again by the grp_by_cols to collect the classes and counts in an array.
    query_result = plpy.execute("""
        SELECT
            -- For each group get the classes and their rows counts
            {grp_identifier} as group_values,
            array_agg(classes) as classes,
            array_agg(class_count) as class_count
        FROM(
            -- for each group and class combination present in source table
            -- get the count of rows for that combination
            SELECT
                {array_grp_by_cols_comma}
                ({class_col})::TEXT AS classes,
                count(*) AS class_count
            FROM {source_table}
            GROUP BY {grp_by_cols_comma} ({class_col})
        ) q
        {meta_grp_by}
        """.format(grp_identifier="group_values" if is_grouping else "NULL",
                   meta_grp_by="GROUP BY group_values" if is_grouping else "",
                   array_grp_by_cols_comma=array_grp_by_cols_comma,
                   grp_by_cols_comma=grp_by_cols_comma,
                   class_col=class_col,
                   source_table=source_table))

    # Without grouping the outer query has no GROUP BY and therefore returns
    # exactly one row. With grouping it returns one row per group; a grouping
    # column holding a single distinct value legitimately yields one row, so
    # only an empty result is inconsistent there.
    n_rows = len(query_result)
    if n_rows == 0 or (not is_grouping and n_rows != 1):
        raise RuntimeError("Balance sample: Error during frequency level distribution")

    actual_grp_level_counts = {}
    for each_row in query_result:
        grp_levels, grp_counts = each_row['classes'], each_row['class_count']
        if grp_levels is None:
            # array_agg over zero input rows yields NULL: the input has no
            # rows left to sample (e.g. all of them were filtered upstream).
            plpy.error("Balance sample: No rows available for sampling.")
        # group_values is a list for each row; convert it to a tuple to use as
        # key in a dictionary
        grp = tuple(each_row['group_values']) if is_grouping else None
        actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts))
    return actual_grp_level_counts
| # ------------------------------------------------------------------------------ |
| |
| |
def _validate_and_get_sampling_strategy(sampling_strategy_str,
                                        output_table_size,
                                        default=UNIFORM):
    """ Returns the sampling strategy based on the class_sizes input param.

    @param sampling_strategy_str The sampling strategy specified by the
                                 user (class_sizes param). A case-insensitive
                                 prefix (at least 3 characters) of a
                                 supported strategy is accepted.
    @param output_table_size     Desired output table size; must not be set
                                 together with UNDERSAMPLE or OVERSAMPLE.
    @param default               Strategy used when sampling_strategy_str is
                                 empty or None.
    @returns:
        Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
    """
    # Bound before branching: the _assert message below references it on
    # every path, including when the default strategy is used.
    supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
    if not sampling_strategy_str:
        sampling_strategy_str = default
    else:
        if len(sampling_strategy_str) < 3:
            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
            # common prefix substring
            plpy.error("Sample: Invalid class_sizes parameter")
        try:
            # allow user to specify a prefix substring of
            # supported strategies.
            sampling_strategy_str = next(
                x for x in supported_strategies
                if x.startswith(sampling_strategy_str.lower()))
        except StopIteration:
            # next() raises StopIteration if no element is found
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})".
                       format(sampling_strategy_str,
                              ','.join(sorted(supported_strategies))))

    _assert(sampling_strategy_str.lower() in supported_strategies or
            (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class_sizes parameter: "
            "{0}. Supported class_size parameters are ({1})".
            format(sampling_strategy_str,
                   ','.join(sorted(supported_strategies))))

    _assert(not(sampling_strategy_str.lower() == OVERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(not(sampling_strategy_str.lower() == UNDERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
| # ------------------------------------------------------------------------------ |
| |
| |
def _choose_strategy(actual_count, desired_count):
    """ Pick the sampling strategy for a single class level.

    @param actual_count:  Number of rows the input holds for the level
    @param desired_count: Number of rows wanted for the level in the output
    @returns:
        Str. OVERSAMPLE when the input has fewer rows than wanted,
        otherwise UNDERSAMPLE.
    """
    # A tie (actual == desired) deliberately maps to UNDERSAMPLE, not
    # NOSAMPLE: copying the rows unchanged would rule out bootstrapping
    # (same row count, drawn with replacement). UNDERSAMPLE supports both
    # with- and without-replacement sampling, while OVERSAMPLE always uses
    # replacement and NOSAMPLE never does.
    needs_more_rows = actual_count < desired_count
    return OVERSAMPLE if needs_more_rows else UNDERSAMPLE
| # ------------------------------------------------------------------------- |
| |
| |
def _get_desired_target_level_counts(desired_level_counts,
                                     actual_grp_level_counts,
                                     output_table_size):
    """ Return the target counts for each group and each class level in the group

    This function is specifically used when the user has provided the
    targets for all (or a subset) of the levels. The strategy of either
    under or oversampling for each class level is chosen based on the
    desired number of counts for a level, and the actual number of counts
    already present in the input table. This calculation is performed for
    each group (or once if no grouping is used).

    @param desired_level_counts: dict {class level: desired row count} for
                                 the whole table. Each count is split evenly
                                 (rounded up) among the groups.
    @param actual_grp_level_counts: dict {group: {class level: row count}},
                                    as produced by
                                    _get_level_frequency_distribution.
    @param output_table_size: Desired size of the whole output, or None.
                              When set, the rows left after the explicit
                              counts are split evenly among the groups and
                              then among the unlisted levels. If nothing is
                              left, unlisted levels get no rows. When unset,
                              unlisted levels are copied as is (NOSAMPLE).

    @returns: dict: The key of the dictionary is the group value and the
                    value is another dictionary. The inner dictionary maps
                    each level present in the group to a
                    (target_count, sampling_strategy) tuple.
    """
    target_grp_level_counts = defaultdict(dict)
    n_grp_values = len(actual_grp_level_counts)
    for each_grp, actual_level_counts in actual_grp_level_counts.items():
        for each_level, desired_count in desired_level_counts.items():
            if each_level not in actual_level_counts:
                plpy.error("Balance sample: Desired class level ({0}) not present in the "
                           "data for each group.".format(each_level))
            # actual_grp_level_counts are group specific, whereas
            # desired_level_counts are evenly split among all groups
            per_grp_desired_count = math.ceil(float(desired_count) / n_grp_values)
            target_grp_level_counts[each_grp][each_level] = (
                per_grp_desired_count,
                _choose_strategy(actual_level_counts[each_level],
                                 per_grp_desired_count))

        # desired levels could contain just a subset of all levels. For the
        # remaining levels in actual_level_counts, compute the desired counts
        remaining_levels = (set(actual_level_counts.keys()) -
                            set(desired_level_counts.keys()))

        if output_table_size:
            # output_table_size is for the whole table and should be split
            # evenly between the groups
            remaining_rows = math.ceil(float(output_table_size -
                                             sum(desired_level_counts.values())) /
                                       n_grp_values)
            # Nothing to distribute when every level already has an explicit
            # count (also avoids dividing by zero below), or when the
            # explicit counts already use up output_table_size.
            if remaining_levels and remaining_rows > 0:
                # Uniformly distribute the remaining rows among the levels
                rows_per_level = math.ceil(float(remaining_rows) /
                                           len(remaining_levels))
                for each_level in remaining_levels:
                    target_grp_level_counts[each_grp][each_level] = (
                        rows_per_level,
                        _choose_strategy(actual_level_counts[each_level],
                                         rows_per_level))
        else:
            # When output_table_size is unspecified, rows from the input table
            # are copied as is for remaining class levels (NOSAMPLE).
            for each_level in remaining_levels:
                target_grp_level_counts[each_grp][each_level] = (
                    actual_level_counts[each_level], NOSAMPLE)
    return target_grp_level_counts
| # ------------------------------------------------------------------------- |
| |
| |
def _get_supported_target_level_counts(sampling_strategy_str,
                                       actual_grp_level_counts,
                                       output_table_size):
    """ Compute per-group target counts for a named sampling strategy.

    Used when class_sizes is one of [uniform, undersample, oversample].
    Within a group, every level receives the same target count:
      - UNIFORM: the group's share of output_table_size (or, when that is
        unset, the group's total row count) divided evenly across its
        levels, rounded up;
      - UNDERSAMPLE: the smallest level count of the group;
      - OVERSAMPLE: the largest level count of the group.
    Whether a level is under- or oversampled then follows from comparing
    its actual count with the target.

    @returns: dict {group: {level: (target_count, sampling_strategy)}}
    @raises RuntimeError for any other sampling_strategy_str.
    """
    n_groups = len(actual_grp_level_counts)
    targets = defaultdict(dict)

    for grp, level_counts in actual_grp_level_counts.items():
        if sampling_strategy_str == UNIFORM:
            if output_table_size:
                # Actual counts are ignored when an output size is given
                grp_total = float(output_table_size) / n_groups
            else:
                grp_total = sum(level_counts.values())
            per_level = math.ceil(float(grp_total) / len(level_counts))
        elif sampling_strategy_str == UNDERSAMPLE:
            per_level = min(level_counts.values())
        elif sampling_strategy_str == OVERSAMPLE:
            per_level = max(level_counts.values())
        else:
            raise RuntimeError("Balance sample: Invalid "
                               "sampling_strategy_str encountered")

        for level, count in level_counts.items():
            targets[grp][level] = (per_level, _choose_strategy(count, per_level))
    return targets
| # ------------------------------------------------------------------------- |
| |
| |
def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
                             actual_grp_level_counts, output_table_size):
    """ Dispatch to the target-count computation matching class_sizes.

    @param sampling_strategy_str: One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE],
                                  or a falsy value (None/'') when the user
                                  gave explicit per-level sizes instead.
    @param desired_level_counts: Dict {class level: number of rows} from the
                                 user; only consulted when
                                 sampling_strategy_str is falsy.
    @param actual_grp_level_counts: Dict giving, for each group, the number
                                    of rows of every class level present in
                                    that group.
    @param output_table_size: Size of the desired output table (NULL or Integer)

    @returns:
        Dict. Per group, the number of samples to draw and the sampling
        strategy to use for each class level.
    """
    if sampling_strategy_str:
        # The user picked one of [uniform, undersample, oversample].
        return _get_supported_target_level_counts(
            sampling_strategy_str, actual_grp_level_counts, output_table_size)

    # The user listed counts for some levels; the remaining levels depend
    # on output_table_size.
    return _get_desired_target_level_counts(
        desired_level_counts, actual_grp_level_counts, output_table_size)
| # ------------------------------------------------------------------------- |
| |
| |
def _get_sampling_strategy_counts(target_class_sizes):
    """ Split target class sizes by sampling strategy.

    @param target_class_sizes: Dict {level: (count, strategy)}.
    @returns: Tuple of three dicts {level: count} in the order
              (undersample, oversample, nosample). Any strategy other than
              UNDERSAMPLE or OVERSAMPLE is filed under nosample.
    """
    buckets = {UNDERSAMPLE: {}, OVERSAMPLE: {}, NOSAMPLE: {}}
    for level, (count, strategy) in target_class_sizes.items():
        key = strategy if strategy in (UNDERSAMPLE, OVERSAMPLE) else NOSAMPLE
        buckets[key][level] = count
    return (buckets[UNDERSAMPLE], buckets[OVERSAMPLE], buckets[NOSAMPLE])
| # ------------------------------------------------------------------------------ |
| |
| |
| def _get_nosample_subquery(source_table, class_col, nosample_levels, |
| grp_dict=None): |
| """ Return the subquery for fetching all rows as is from the input table |
| for specific class levels. |
| """ |
| if not nosample_levels: |
| return '' |
| nosample_level_str = ','.join(["'{0}'".format(level) |
| for level in nosample_levels if level]) |
| if grp_dict: |
| grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v) |
| for k, v in grp_dict.items()) |
| else: |
| grp_filter = '' |
| return """ |
| SELECT * |
| FROM {source_table} |
| WHERE ({class_col} in ({nosample_level_str}) OR |
| {class_col} IS NULL) |
| {grp_filter} |
| """.format(**locals()) |
| # ------------------------------------------------------------------------------ |
| |
| |
def _get_without_replacement_subquery(schema_madlib, source_table, class_col,
                                      actual_level_counts, desired_level_counts,
                                      grp_dict=None):
    """ Return the subquery for sampling without replacement for specific
    class levels.

    The rows of each requested level are numbered in random order and only
    the first <desired count> rows are kept, so a row appears at most once.

    @param schema_madlib: Not used by this function.
    @param source_table: Table (or view) to sample from.
    @param class_col: Column (or expression) holding the class level.
    @param actual_level_counts: Not used by this function; accepted so it
                                takes the same arguments as
                                _get_with_replacement_subquery.
    @param desired_level_counts: Dict {class level: rows to keep}; the key
                                 None stands for the NULL class level.
    @param grp_dict: Optional dict {grouping column: value} restricting rows
                     to one group; a None value matches NULL.
    @returns: Str. A SELECT statement, or '' if there is nothing to sample.
    """
    if not desired_level_counts:
        return ''

    def _quote(value):
        # SQL string literal for value, with embedded single quotes doubled.
        return "'" + "{0}".format(value).replace("'", "''") + "'"

    class_col_tmp = unique_string(desp='class_col')
    row_number_col = unique_string(desp='row_number')
    desired_count_col = unique_string(desp='desired_count')
    source_table_columns = ','.join(get_cols(source_table))

    # (class_level, desired_count) tuples. The NULL level is encoded as
    # NULL_IDENTIFIER so it can be matched via coalesce() below. Test for
    # None explicitly: '' is a valid (non-NULL) class level.
    desired_level_counts_str = "VALUES " + ','.join(
        "({0}::text, {1})".format(
            _quote(NULL_IDENTIFIER if level is None else level), count)
        for level, count in desired_level_counts.items())

    grp_filter = ''.join(
        " AND {0} IS NULL".format(col) if val is None
        else " AND {0} = {1}".format(col, _quote(val))
        for col, val in (grp_dict or {}).items())

    # q1: the VALUES list, giving the desired row count per class level.
    # q2: source rows of the requested levels (and group), each tagged with
    #     the desired count of its level.
    # q3: numbers the rows of each level in random order.
    # The outer WHERE keeps the first <desired count> rows of each level.
    return """
        SELECT {source_table_columns}
        FROM (
            SELECT {source_table_columns},
                   row_number() OVER (PARTITION BY {class_col}
                                      ORDER BY random()) AS {row_number_col},
                   {desired_count_col}
            FROM (
                SELECT {source_table_columns},
                       {desired_count_col}
                FROM
                    {source_table} s,
                    ({desired_level_counts_str})
                        q1({class_col_tmp}, {desired_count_col})
                WHERE {class_col_tmp} = coalesce({class_col}::text, {null_literal})
                {grp_filter}
            ) q2
        ) q3
        WHERE {row_number_col} <= {desired_count_col}
        """.format(source_table_columns=source_table_columns,
                   class_col=class_col,
                   row_number_col=row_number_col,
                   desired_count_col=desired_count_col,
                   source_table=source_table,
                   desired_level_counts_str=desired_level_counts_str,
                   class_col_tmp=class_col_tmp,
                   null_literal=_quote(NULL_IDENTIFIER),
                   grp_filter=grp_filter)
| # ------------------------------------------------------------------------------ |
| |
| |
def _get_with_replacement_subquery(schema_madlib, source_table, class_col,
                                   actual_level_counts, desired_level_counts,
                                   grp_dict=None):
    """ Return the query for sampling with replacement for specific class
    levels. Always used for oversampling since oversampling will always need
    to use replacement. Used for under sampling only if with_replacement
    flag is set to TRUE.

    @param schema_madlib: Not used by this function.
    @param source_table: Table (or view) to sample from.
    @param class_col: Column (or expression) holding the class level.
    @param actual_level_counts: Dict {class level: rows present}, used to
                                draw row numbers in 1..count; must contain
                                every key of desired_level_counts.
    @param desired_level_counts: Dict {class level: rows to draw}; the key
                                 None stands for the NULL class level.
    @param grp_dict: Optional dict {grouping column: value} restricting rows
                     to one group; a None value matches NULL.
    @returns: Str. A SELECT statement, or '' if there is nothing to sample.
    """
    if not desired_level_counts:
        return ''

    def _quote(value):
        # SQL string literal for value, with embedded single quotes doubled.
        return "'" + "{0}".format(value).replace("'", "''") + "'"

    source_table_columns = ','.join(get_cols(source_table))
    class_col_tmp = unique_string(desp='class_col_with_rep')
    desired_count_col = unique_string(desp='desired_count_with_rep')
    actual_count_col = unique_string(desp='actual_count')
    q1_row_no = unique_string(desp='q1_row')
    q2_row_no = unique_string(desp='q2_row')

    # (class_level, desired_rows, actual_rows_in_input) tuples. The NULL
    # level is encoded as NULL_IDENTIFIER so it can be matched via
    # coalesce() below. Test for None explicitly: '' is a valid class level.
    desired_and_actual_level_count_str = "VALUES " + ','.join(
        "({0}::text, {1}, {2})".format(
            _quote(NULL_IDENTIFIER if level is None else level),
            count, actual_level_counts[level])
        for level, count in desired_level_counts.items())

    grp_conditions = [
        "{0} IS NULL".format(col) if val is None
        else "{0} = {1}".format(col, _quote(val))
        for col, val in (grp_dict or {}).items()]
    grp_filter = ('WHERE ' + ' AND '.join(grp_conditions)) if grp_conditions else ''

    # q1: uses generate_series to emit one row per desired sample of each
    #     level, each carrying a row number drawn uniformly from
    #     1..actual_rows. random() lies in [0, 1), so floor(random() * n) + 1
    #     covers 1..n with equal probability. Rounding random()*(n-1)+1 to an
    #     integer instead would give 1 and n only half the weight of the rest.
    # q2: the source rows (of the group), numbered 1..actual_rows per level.
    # The WHERE clause joins a drawn row number to the source row carrying it.
    return """
        SELECT {source_table_columns}
        FROM
            (
                SELECT
                    {class_col_tmp},
                    generate_series(1, {desired_count_col}::int) AS _i,
                    (floor(random() * {actual_count_col}) + 1)::int AS {q1_row_no}
                FROM
                    ({desired_and_actual_level_count_str})
                        q({class_col_tmp}, {desired_count_col}, {actual_count_col})
            ) q1,
            (
                SELECT
                    *,
                    row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
                FROM
                    {source_table}
                {grp_filter}
            ) q2
        WHERE {class_col_tmp} = coalesce({class_col}::text, {null_literal}) AND
              q1.{q1_row_no} = q2.{q2_row_no}
        """.format(source_table_columns=source_table_columns,
                   class_col_tmp=class_col_tmp,
                   desired_count_col=desired_count_col,
                   actual_count_col=actual_count_col,
                   q1_row_no=q1_row_no,
                   q2_row_no=q2_row_no,
                   desired_and_actual_level_count_str=desired_and_actual_level_count_str,
                   class_col=class_col,
                   source_table=source_table,
                   grp_filter=grp_filter,
                   null_literal=_quote(NULL_IDENTIFIER))
| # ------------------------------------------------------------------------------ |
| |
| |
def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, keep_null, **kwargs):
    """
    Balance sampling function
    Args:
        @param schema_madlib        Schema that MADlib is installed on.
        @param source_table         Input table name.
        @param output_table         Output table name.
        @param class_col            Name of the column containing the class to be
                                    balanced.
        @param class_sizes          Parameter to define the size of the different
                                    class values.
        @param output_table_size    Desired size of the output data set.
        @param grouping_cols        The columns that define the grouping.
        @param with_replacement     The sampling method.
        @param keep_null            Flag to include rows with class level values
                                    NULL. Default is False.

    The output table holds the sampled rows plus a NEW_ID_COLUMN column that
    numbers them consecutively from 1, across all groups.
    """
    with MinWarning("warning"):
        # set all default values
        class_sizes = class_sizes.strip() if class_sizes else UNIFORM
        with_replacement = bool(with_replacement)
        keep_null = bool(keep_null)

        _validate_strs(source_table, output_table, class_col,
                       output_table_size, grouping_cols)

        # If keep_null=False, create a view of the input table ignoring NULL
        # values for class levels.
        if keep_null:
            new_source_table = source_table
        else:
            new_source_table = unique_string(desp='source_table')
            plpy.execute("""
                CREATE VIEW {new_source_table} AS
                SELECT * FROM {source_table}
                WHERE {class_col} IS NOT NULL
                """.format(new_source_table=new_source_table,
                           source_table=source_table,
                           class_col=class_col))
        try:
            # class_sizes can be of two forms:
            # 1. A string describing sampling strategy (as described in
            #    _validate_and_get_sampling_strategy).
            #    In this case, 'sampling_strategy_str' is set to one of
            #    [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
            # 2. Class sizes for all (or a subset) of the class levels
            #    In this case, sampling_strategy_str = None and parsed_class_sizes
            #    is used for the sampling.
            parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                         allow_duplicates=False,
                                                         lower_case_names=False)

            # GREENPLUM 4.3.X does not support bool-to-text cast which is relied
            # upon for class_col in multiple queries. The explicit_bool_to_text
            # function wraps the class_col with a MADlib function that provides the
            # cast just for those platforms that don't provide the cast. For
            # platforms that provide the cast, class_col is unchanged below.
            class_col = explicit_bool_to_text(source_table,
                                              [class_col],
                                              schema_madlib)[0]
            distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
                            format(class_col, source_table))
            distinct_levels = collate_plpy_result(
                plpy.execute(distinct_sql))['levels']
            if not parsed_class_sizes:
                sampling_strategy_str = _validate_and_get_sampling_strategy(
                    class_sizes, output_table_size)
            else:
                sampling_strategy_str = None
                try:
                    for each_level, each_class_size in parsed_class_sizes.items():
                        _assert(each_level in distinct_levels,
                                "Sample: Invalid class value specified ({0})".
                                format(each_level))
                        each_class_size = int(each_class_size)
                        _assert(each_class_size >= 1,
                                "Sample: Class size has to be greater than zero")
                        parsed_class_sizes[each_level] = each_class_size
                except ValueError:
                    plpy.error("Sample: Invalid value for class_sizes ({0})".
                               format(class_sizes))

            # Get the number of rows to be sampled for each class level, based on
            # the input table, class_sizes, and output_table_size params. This also
            # includes info about the resulting sampling strategy, i.e., one of
            # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
            grp_col_str, grp_cols = get_grouping_col_str(
                schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
                source_table, grouping_cols)
            actual_grp_level_counts = _get_level_frequency_distribution(
                new_source_table, class_col, grp_col_str)
            target_grp_class_sizes = _get_target_level_counts(
                sampling_strategy_str, parsed_class_sizes,
                actual_grp_level_counts, output_table_size)

            grp_cols_list = grp_col_str.split(',') if grp_cols else []
            n_groups = len(actual_grp_level_counts)
            is_output_created = False
            # Rows written so far; ids of the next group continue after them
            # so that NEW_ID_COLUMN stays unique across groups.
            n_rows_written = 0

            # for each group
            for grp_vals, actual_level_counts in actual_grp_level_counts.items():
                target_class_sizes = target_grp_class_sizes[grp_vals]

                (undersample_level_counts,
                 oversample_level_counts,
                 nosample_level_counts) = _get_sampling_strategy_counts(target_class_sizes)

                # grp_dict represents each grouping column and its value
                # for current iteration
                grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None

                # NOSAMPLE are levels that are to be retained in output without sampling
                nosample_subquery = _get_nosample_subquery(
                    new_source_table, class_col, nosample_level_counts.keys(), grp_dict)
                # OVERSAMPLE are levels that are to be sampled more than the
                # existing count (always with replacement)
                oversample_subquery = _get_with_replacement_subquery(
                    schema_madlib, new_source_table, class_col,
                    actual_level_counts, oversample_level_counts, grp_dict)
                # UNDERSAMPLE are levels that are to be sampled to less than
                # existing count. Undersampling supports both with and without
                # replacement.
                if with_replacement:
                    undersample_subquery = _get_with_replacement_subquery(
                        schema_madlib, new_source_table, class_col,
                        actual_level_counts, undersample_level_counts, grp_dict)
                else:
                    undersample_subquery = _get_without_replacement_subquery(
                        schema_madlib, new_source_table, class_col,
                        actual_level_counts, undersample_level_counts, grp_dict)

                # Merge the three subqueries using a UNION ALL clause.
                sampling_queries = (undersample_subquery, oversample_subquery,
                                    nosample_subquery)
                union_all_subquery = ' UNION ALL '.join(
                    '({0})'.format(subquery)
                    for subquery in sampling_queries if subquery)

                # Populate the output table
                if is_output_created:
                    table_header = "INSERT INTO {0}".format(output_table)
                else:
                    table_header = "CREATE TABLE {0} AS ".format(output_table)
                final_query = """
                    {table_header}
                    SELECT {id_offset} + row_number() OVER() AS {id_col_name}, *
                    FROM (
                        {union_all_subquery}
                    ) union_query
                    """.format(table_header=table_header,
                               id_offset=n_rows_written,
                               id_col_name=NEW_ID_COLUMN,
                               union_all_subquery=union_all_subquery)
                result = plpy.execute(final_query)

                if is_output_created:
                    n_rows_written += result.nrows()
                else:
                    is_output_created = True
                    if n_groups > 1:
                        # Count the created rows directly instead of relying
                        # on the row count reported for CREATE TABLE AS.
                        n_rows_written = plpy.execute(
                            "SELECT count(*) AS n_rows FROM {0}".format(output_table)
                        )[0]['n_rows']
            # end of grouping loop
        finally:
            # Drop the temporary view even if sampling fails part way.
            if not keep_null:
                plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table))
| # ------------------------------------------------------------------------------ |
| |
| |
def _validate_strs(source_table, output_table, class_col,
                   output_table_size, grouping_cols):
    """ Validate the table, column and size arguments of balance_sample.

    Checks run in order and each failure raises via _assert with a
    "Sample: ..." message; e.g. emptiness is only checked once the source
    table is known to exist. grouping_cols is accepted but not checked here.
    """
    # Source table: named, existing and non-empty
    _assert(source_table and table_exists(source_table),
            "Sample: Source table ({0}) does not exist.".format(source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({0}) is empty.".format(source_table))

    # Output table: named and not yet existing
    _assert(output_table, "Sample: Output table name is missing.")
    _assert(not table_exists(output_table),
            "Sample: Output table ({0}) already exists.".format(output_table))

    # Class column: named and present in the source table
    _assert(class_col, "Sample: Class column name is missing.")
    class_col_msg = ("Sample: Class column ({0}) does not exist in"
                     " table ({1}).").format(class_col, source_table)
    _assert(columns_exist_in_table(source_table, [class_col]), class_col_msg)

    # The id column added to the output must not clash with an input column
    id_col_msg = ("Sample: Please ensure the source table ({0})"
                  " does not contain a column named {1}").format(source_table,
                                                                 NEW_ID_COLUMN)
    _assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]), id_col_msg)

    size_ok = not output_table_size or output_table_size > 0
    _assert(size_ok,
            "Sample: Invalid output table size ({0}).".format(output_table_size))
| # ------------------------------------------------------------------------------ |
| |
| |
def balance_sample_help(schema_madlib, message, **kwargs):
    """
    Help function for balance_sample

    Args:
        @param schema_madlib: Schema MADlib is installed in; substituted
                              into the returned text.
        @param message: string, Help message string. Empty/None returns the
                        summary; 'usage', 'help' or '?' (case-insensitive)
                        return the usage text; anything else returns a hint.
        @param kwargs: Unused.

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
Given a table with a varying number of records for each class label, this
function creates an output table in which the class labels follow a chosen
sampling distribution (by default: uniform). It is possible to sample with
or without replacement, to specify different sizes for each class, to
sample each group defined by one or more grouping columns, and to set the
output table size.

For more details on function usage:
    SELECT {schema_madlib}.balance_sample('usage');
        """
    elif message.lower() in ['usage', 'help', '?']:
        help_string = """
Given a table, balanced sampling returns a sample of its records in which
the class labels follow a chosen distribution: uniform, undersampled to the
smallest class, oversampled to the largest class, or user-defined class
sizes.

----------------------------------------------------------------------------
                                USAGE
----------------------------------------------------------------------------

 SELECT {schema_madlib}.balance_sample(
    source_table      TEXT,     -- Input table name.
    output_table      TEXT,     -- Output table name.
    class_col         TEXT,     -- Name of column containing the class to be
                                -- balanced.
    class_size        TEXT,     -- (Default: NULL) Parameter to define the size
                                -- of the different class values.
    output_table_size INTEGER,  -- (Default: NULL) Desired size of the output
                                -- data set.
    grouping_cols     TEXT,     -- (Default: NULL) The columns that define the
                                -- grouping.
    with_replacement  BOOLEAN,  -- (Default: FALSE) The sampling method.
    keep_null         BOOLEAN   -- (Default: FALSE) Consider class levels with
                                -- NULL values or not.
 );

class_size can be one of 'uniform', 'undersample' or 'oversample' (a prefix
of at least 3 characters is accepted), or a comma-separated list of
'class_value=size' pairs for some or all of the class values. If class_size
is NULL, the source table is uniformly sampled. When only some class values
are listed, the others are copied as is if output_table_size is NULL, and
otherwise share the remaining rows evenly.

If output_table_size is NULL, the resulting output table size will depend on
the settings for the 'class_size' parameter. It must be NULL if 'class_size'
is set to either 'oversample' or 'undersample'.

If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly. Otherwise each group is sampled separately, and
output_table_size and explicitly listed class sizes are divided evenly
among the groups.

If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once. Classes that need more rows
than they contain are always sampled with replacement.

If keep_null is TRUE, NULL is treated as a class value of its own;
otherwise rows with a NULL class value are ignored.

The output_table contains the required number of samples, along with a
new column named __madlib_id__ that contains unique numbers for all
sampled rows.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.balance_sample()"

    return help_string.format(schema_madlib=schema_madlib)
| |
| |
| import unittest |
| |
| |
class UtilitiesTestCase(unittest.TestCase):
    """
    Unit tests for the pure-Python balance_sample helpers
    (_choose_strategy, _get_target_level_counts and
    _get_sampling_strategy_counts).

    Run them by executing this file directly. In that mode the module does
    not import plpy (see the __main__ guard at the top of the file), so any
    code path that reaches a plpy.* call such as plpy.error raises
    NameError; replace such calls with ordinary Python exceptions before
    exercising them from these tests.
    """

    def setUp(self):
        # Per-level row counts of the input: 75 rows over levels a/b/c.
        # The outer key is presumably the group; None is used here for the
        # ungrouped case -- TODO confirm against _get_target_level_counts.
        self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}}

        # Strategy strings. The empty string is the one paired with explicit
        # per-level sizes below; the named strategies are paired with an
        # empty class-size specification.
        self.sampling_strategy_str0 = ''
        self.sampling_strategy_str1 = 'uniform'
        self.sampling_strategy_str2 = 'oversample'
        self.sampling_strategy_str3 = 'undersample'

        # User-specified target sizes for a subset of the class levels.
        self.user_specified_class_size0 = ''
        self.user_specified_class_size1 = {'a': 25, 'b': 25}
        self.user_specified_class_size2 = {'b': 25}
        self.user_specified_class_size3 = {'a': 30}

        # Desired output table sizes (None means "not specified").
        self.output_table_size1 = None
        self.output_table_size2 = 60

    def test__choose_strategy(self):
        # Assuming the arguments are (actual count, target count): a level
        # with fewer rows than its target is oversampled, while a level with
        # more rows -- or exactly the target -- is undersampled.
        self.assertEqual(UNDERSAMPLE, _choose_strategy(35, 25))
        self.assertEqual(OVERSAMPLE, _choose_strategy(15, 25))
        self.assertEqual(UNDERSAMPLE, _choose_strategy(25, 25))

    def test__get_target_level_counts(self):
        counts = self.input_class_level_counts1

        def target_counts(strategy, class_size, output_table_size):
            # Every case below samples the same input level counts.
            return _get_target_level_counts(strategy, class_size, counts,
                                            output_table_size)

        # User-specified class sizes, no output table size: listed levels get
        # the requested size, unlisted levels are expected to keep their input
        # count (NOSAMPLE).
        self.assertEqual(
            {None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size1,
                          self.output_table_size1))
        self.assertEqual(
            {None: {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size2,
                          self.output_table_size1))
        self.assertEqual(
            {None: {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size3,
                          self.output_table_size1))

        # User-specified class sizes with an output table size of 60: the rows
        # left after the listed levels are expected to be shared among the
        # unlisted levels.
        # 60 - (25 + 25) = 10 rows for 'c'.
        self.assertEqual(
            {None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size1,
                          self.output_table_size2))
        # 60 - 25 = 35 rows over 'a' and 'c' (17.5 each, expected as 18).
        self.assertEqual(
            {None: {'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size2,
                          self.output_table_size2))
        # 60 - 30 = 30 rows over 'b' and 'c' (15 each).
        self.assertEqual(
            {None: {'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str0,
                          self.user_specified_class_size3,
                          self.output_table_size2))

        # Named strategies without an output table size.
        # 'uniform': 75 rows / 3 levels = 25 per level.
        self.assertEqual(
            {None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str1,
                          self.user_specified_class_size0,
                          self.output_table_size1))
        # 'oversample': every level raised to the largest count (30).
        self.assertEqual(
            {None: {'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}},
            target_counts(self.sampling_strategy_str2,
                          self.user_specified_class_size0,
                          self.output_table_size1))
        # 'undersample': every level lowered to the smallest count (20).
        self.assertEqual(
            {None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str3,
                          self.user_specified_class_size0,
                          self.output_table_size1))

        # 'uniform' with an output table size: 60 rows / 3 levels = 20 each.
        self.assertEqual(
            {None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
            target_counts(self.sampling_strategy_str1,
                          self.user_specified_class_size0,
                          self.output_table_size2))

    def test__get_sampling_strategy_specific_dict(self):
        # Exercises _get_sampling_strategy_counts. The expected tuples below
        # are ordered (undersample, oversample, nosample): each level's target
        # count lands in the dict matching its strategy, and a strategy with
        # no levels yields an empty dict.
        target_level_counts_1 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}
        target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)}
        target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 'c': (25, NOSAMPLE)}
        self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
                         _get_sampling_strategy_counts(target_level_counts_1))
        self.assertEqual(({'b': 25}, {'a': 25}, {}),
                         _get_sampling_strategy_counts(target_level_counts_2))
        self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
                         _get_sampling_strategy_counts(target_level_counts_3))
| |
| |
if __name__ == "__main__":
    # Invoked as a script: hand control to the unittest runner, which
    # collects the TestCase classes defined in this module.
    unittest.main(module="__main__")