# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# m4_changequote(`<!', `!>')
import math
from collections import defaultdict
if __name__ != "__main__":
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import collate_plpy_result
from utilities.utilities import get_grouping_col_str
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import explicit_bool_to_text
from utilities.validate_args import get_cols
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
# Used only for Unit Testing
# FIXME: repeating a function from utilities that is needed by the unit test.
# This should be removed once a unittest framework is used for testing.
import random
import time
def unique_string(desp='', **kwargs):
    """
    Generate random temporary names for temp table and other names.
    It has a SQL interface so both SQL and Python functions can call it.

    @param desp: Text embedded in the generated name right after the
                 '__madlib_temp_' prefix.
    @param kwargs: Accepted and ignored.

    @returns:
        Str. '__madlib_temp_<desp><r1>_<r2>_<r3>__' where r1 is a random
        integer, r2 the current epoch time in seconds, and r3 that time
        modulo a second random integer.
    """
    r1 = random.randint(1, 100000000)
    r2 = int(time.time())
    r3 = int(time.time()) % random.randint(1, 100000000)
    u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__"
    return u_string
# ------------------------------------------------------------------------------
# Names of the sampling strategies a user can request through class_sizes.
UNIFORM = 'uniform'
UNDERSAMPLE = 'undersample'
OVERSAMPLE = 'oversample'
# Internal strategy tag for class levels whose rows are copied to the output
# without any sampling.
NOSAMPLE = 'nosample'
# Name of the row-number column added to the output table; the source table
# must not already contain a column with this name.
NEW_ID_COLUMN = '__madlib_id__'
# Text placeholder that stands in for a NULL class level when class levels
# are compared as text inside the sampling subqueries.
NULL_IDENTIFIER = '__madlib_null_id__'
def _get_level_frequency_distribution(source_table, class_col,
                                      grp_by_cols=None):
    """ Count the number of rows for each class, partitioned by the grp_by_cols

    @param source_table: Name of the table (or view) whose rows are counted.
    @param class_col: Class column expression; it is cast to TEXT in the query.
    @param grp_by_cols: Comma-separated grouping columns. An empty value or
                        the string 'null' (any case) disables grouping.

    @returns:
        Dict of dicts. The outer key is a tuple of the group values (or None
        when there is no grouping). The inner dict maps each class level,
        converted to a string using ::text, to its number of rows.
        None is a valid key in the inner dict, capturing NULL value in the
        database.

    @raises RuntimeError: if the number of result rows does not match the
                          grouping mode (see the check below).
    """
    if grp_by_cols and grp_by_cols.lower() != 'null':
        is_grouping = True
        grp_by_cols_comma = grp_by_cols + ', '
        array_grp_by_cols_comma = "ARRAY[{0}]".format(grp_by_cols) + " AS group_values, "
    else:
        is_grouping = False
        grp_by_cols_comma = array_grp_by_cols_comma = ""

    # In below query, the inner query groups the data using grp_by_cols + classes
    # and obtains the count for each combination. The outer query then groups
    # again by the grp_by_cols to collect the classes and counts in an array.
    query_result = plpy.execute("""
        -- For each group get the classes and their rows counts
        SELECT
            {grp_identifier} as group_values,
            array_agg(classes) as classes,
            array_agg(class_count) as class_count
        FROM (
            -- for each group and class combination present in source table
            -- get the count of rows for that combination
            SELECT
                {array_grp_by_cols_comma}
                ({class_col})::TEXT AS classes,
                count(*) AS class_count
            FROM {source_table}
            GROUP BY {grp_by_cols_comma} ({class_col})
        ) q
        {meta_grp_by}
        """.format(grp_identifier="group_values" if is_grouping else "NULL",
                   meta_grp_by="GROUP BY group_values" if is_grouping else "",
                   array_grp_by_cols_comma=array_grp_by_cols_comma,
                   grp_by_cols_comma=grp_by_cols_comma,
                   class_col=class_col,
                   source_table=source_table))

    if (len(query_result) > 1) != is_grouping:
        # if is_grouping then query_result should have more than 1 row
        # if not is_grouping then query_result should have only 1 row
        # NOTE(review): as written, grouping columns that yield exactly one
        # group also land here - confirm that this is intended.
        raise RuntimeError("Balance sample: Error during frequency level distribution")

    actual_grp_level_counts = {}
    for each_row in query_result:
        # group_values is a list for each row; convert it to a tuple to use as
        # key in a dictionary
        grp = tuple(each_row['group_values']) if is_grouping else None
        grp_levels, grp_counts = each_row['classes'], each_row['class_count']
        actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts))
    return actual_grp_level_counts
# ------------------------------------------------------------------------------
def _validate_and_get_sampling_strategy(sampling_strategy_str,
                                        output_table_size,
                                        supported_strategies=None,
                                        default=UNIFORM):
    """ Returns the sampling strategy based on the class_sizes input param.

    @param sampling_strategy_str The sampling strategy specified by the
                                 user (class_sizes param)
    @param output_table_size     Requested output size; must be unset when
                                 the strategy is oversample or undersample
    @param supported_strategies  Strategy names to match against. Defaults
                                 to (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
    @param default               Strategy used when sampling_strategy_str
                                 is empty

    @returns:
        Str. The full name of the matched strategy, or 'default' when no
        strategy was given.
    """
    # Resolve the list up front so that every error message below can be
    # built, including when the 'default' branch is taken.
    if not supported_strategies:
        supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
    supported_strategies_str = ','.join(sorted(supported_strategies))

    if not sampling_strategy_str:
        sampling_strategy_str = default
    else:
        if len(sampling_strategy_str) < 3:
            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
            # common prefix substring
            plpy.error("Sample: Invalid class_sizes parameter")

        try:
            # allow user to specify a prefix substring of
            # supported strategies.
            sampling_strategy_str = next(x for x in supported_strategies
                                         if x.startswith(sampling_strategy_str.lower()))
        except StopIteration:
            # next() returns a StopIteration if no element found
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})".
                       format(sampling_strategy_str, supported_strategies_str))

    _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or
            (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class_sizes parameter: "
            "{0}. Supported class_size parameters are ({1})".
            format(sampling_strategy_str, supported_strategies_str))

    _assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size),
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(not(sampling_strategy_str.lower() == 'undersample' and output_table_size),
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
# ------------------------------------------------------------------------------
def _choose_strategy(actual_count, desired_count):
    """ Choose sampling strategy by comparing actual and desired sample counts

    @param actual_count: Actual number of samples for some level
    @param desired_count: Desired number of sample for the level

    @returns:
        Str. Sampling strategy string (either UNDERSAMPLE or OVERSAMPLE)
    """
    # OVERSAMPLE when the actual count is less than the desired count
    # UNDERSAMPLE when the actual count is more than the desired count

    # If the actual count for a class level is the same as desired count, then
    # we could potentially return the input rows as is. This, however,
    # precludes the case of bootstrapping (i.e. returning same number of rows
    # but after sampling with replacement). Hence, we treat the actual=desired
    # as UNDERSAMPLE. It's specifically set to UNDERSAMPLE since it provides
    # both 'with' and 'without' replacement (OVERSAMPLE is always with
    # replacement and NOSAMPLE is always without replacement)
    if actual_count < desired_count:
        return OVERSAMPLE
    return UNDERSAMPLE
# -------------------------------------------------------------------------
def _get_desired_target_level_counts(desired_level_counts,
                                     actual_grp_level_counts,
                                     output_table_size):
    """ Return the target counts for each group and each class level in the group

    This function is specifically used when the user has provided the
    targets for all (or a subset) of the levels. The strategy of either
    under or oversampling for each class level is chosen based on the
    desired number of counts for a level, and the actual number of counts
    already present in the input table. This calculation is performed for
    each group (or once if no grouping is used).

    @param desired_level_counts: Dict of class level -> number of rows the
                                 user asked for (across the whole table)
    @param actual_grp_level_counts: Dict of group -> {class level -> number
                                    of rows present in the input}
    @param output_table_size: Desired total output size, or a falsy value
                              when unspecified

    @returns: dict: The key of the dictionary is the group value and the
                    value is another dictionary. The inner dictionary gives
                    the (target count, sampling strategy) pair for each level
                    present in the group.
    """
    target_grp_level_counts = defaultdict(dict)
    n_grp_values = len(actual_grp_level_counts)
    for each_grp, actual_level_counts in actual_grp_level_counts.items():
        for each_level, desired_count in desired_level_counts.items():
            # actual_grp_level_counts are group specific, whereas
            # desired_level_counts are evenly split among all groups
            per_grp_desired_count = math.ceil(float(desired_count) / n_grp_values)
            try:
                actual_count = actual_level_counts[each_level]
            except KeyError:
                plpy.error("Balance sample: Desired class level ({0}) not present in the "
                           "data for each group.".format(each_level))
            else:
                sample_strategy = _choose_strategy(actual_count,
                                                   per_grp_desired_count)
                target_grp_level_counts[each_grp][each_level] = (
                    per_grp_desired_count, sample_strategy)

        # desired levels could contain just a subset of all levels. For the remaining
        # levels in actual_level_counts, compute the desired counts
        remaining_levels = (set(actual_level_counts.keys()) -
                            set(desired_level_counts.keys()))

        # if 'output_table_size' = NULL, remaining level counts remain as is
        # if 'output_table_size' = <Integer>, divide remaining count
        #                           uniformly among remaining levels
        if output_table_size:
            # output_table_size is for the whole table and should be split evenly
            # between the groups
            remaining_rows = math.ceil(float(output_table_size -
                                             sum(desired_level_counts.values())) /
                                       n_grp_values)
            # Nothing to distribute when every level was given explicitly;
            # the check also keeps the division below away from zero.
            if remaining_rows > 0 and remaining_levels:
                # Uniformly distribute the remaining class levels
                rows_per_level = math.ceil(float(remaining_rows) /
                                           len(remaining_levels))
                for each_level in remaining_levels:
                    sample_strategy = _choose_strategy(
                        actual_level_counts[each_level], rows_per_level)
                    target_grp_level_counts[each_grp][each_level] = (
                        rows_per_level, sample_strategy)
        else:
            # When output_table_size is unspecified, rows from the input table
            # are sampled as is for remaining class levels. This is called as the
            # NOSAMPLE strategy.
            for each_level in remaining_levels:
                target_grp_level_counts[each_grp][each_level] = (
                    actual_level_counts[each_level], NOSAMPLE)
    return target_grp_level_counts
# -------------------------------------------------------------------------
def _get_supported_target_level_counts(sampling_strategy_str,
                                       actual_grp_level_counts,
                                       output_table_size):
    """ Returns the target level counts for all levels when the class_size param
    is one of [uniform, undersample, oversample]. The strategy of
    (under)oversampling for a specific class level is chosen based on the
    computed number of counts for a level, and the actual number of counts
    already present in the input table.

    @param sampling_strategy_str: One of UNIFORM, UNDERSAMPLE, OVERSAMPLE
    @param actual_grp_level_counts: Dict of group -> {class level -> number
                                    of rows present in the input}
    @param output_table_size: Desired total output size, or a falsy value
                              when unspecified. Only read for UNIFORM.

    @returns: dict of group -> {class level -> (target count, strategy)}

    @raises RuntimeError: if sampling_strategy_str is not a known strategy
    """
    target_grp_level_counts = defaultdict(dict)
    n_grp_values = len(actual_grp_level_counts)
    for each_grp, actual_level_counts in actual_grp_level_counts.items():
        if sampling_strategy_str == UNIFORM:
            # UNIFORM: Ensure all level counts are same
            if output_table_size:
                # Ignore actual counts for computing target sizes
                # if output_table_size is specified
                total_ = float(output_table_size) / n_grp_values
            else:
                total_ = sum(actual_level_counts.values())
            target_size_per_level = math.ceil(float(total_) /
                                              len(actual_level_counts))
        elif sampling_strategy_str == UNDERSAMPLE:
            # UNDERSAMPLE: Ensure all level counts are same as the minimum count
            target_size_per_level = min(actual_level_counts.values())
        elif sampling_strategy_str == OVERSAMPLE:
            # OVERSAMPLE: Ensure all level counts are same as the maximum count
            target_size_per_level = max(actual_level_counts.values())
        else:
            raise RuntimeError("Balance sample: Invalid "
                               "sampling_strategy_str encountered")

        for each_level, actual_count in actual_level_counts.items():
            sample_strategy = _choose_strategy(actual_count, target_size_per_level)
            target_grp_level_counts[each_grp][each_level] = (
                target_size_per_level, sample_strategy)
    return target_grp_level_counts
# -------------------------------------------------------------------------
def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
                             actual_grp_level_counts, output_table_size):
    """ Compute the per-group, per-level sampling targets.

    @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None].
                               This is 'None' only if this is user-defined, i.e.,
                               a comma separated list of class levels and number
                               of rows desired pairs.
    @param desired_level_counts: Dict that is defined only when the previous arg
                                 sampling_strategy_str is None. This dict would
                                 then contain the class levels and the
                                 corresponding number of rows specified by
                                 the user.
    @param actual_grp_level_counts: Dictionary that provides for each group the
                                    the count of number of rows for each class
                                    present in the group
    @param output_table_size: Size of the desired output table (NULL or Integer)

    @returns:
        Dict. Number of samples to be drawn, and the sampling strategy to be
              used for each class level.
    """
    if not sampling_strategy_str:
        # This case implies user has provided a desired count for one or more
        # levels. Counts for rest of the levels depend on 'output_table_size'.
        return _get_desired_target_level_counts(
            desired_level_counts, actual_grp_level_counts, output_table_size)

    # This case implies the user has chosen one of
    # [uniform, undersample, oversample] for the class_size parameter.
    return _get_supported_target_level_counts(
        sampling_strategy_str, actual_grp_level_counts, output_table_size)
# -------------------------------------------------------------------------
def _get_sampling_strategy_counts(target_class_sizes):
    """ Return three dicts, one each for undersampling, oversampling, and
    nosampling. The dict contains the number of samples to be drawn for
    each class level.

    @param target_class_sizes: Dict of class level -> (count, strategy)

    @returns:
        Tuple (undersample, oversample, nosample) of dicts, each mapping a
        class level to its count. Any strategy other than UNDERSAMPLE or
        OVERSAMPLE is filed under nosample.
    """
    undersample_level_counts = {}
    oversample_level_counts = {}
    nosample_level_counts = {}
    for level, (count, strategy) in target_class_sizes.items():
        if strategy == UNDERSAMPLE:
            undersample_level_counts[level] = count
        elif strategy == OVERSAMPLE:
            oversample_level_counts[level] = count
        else:
            nosample_level_counts[level] = count
    return (undersample_level_counts, oversample_level_counts, nosample_level_counts)
# ------------------------------------------------------------------------------
def _get_nosample_subquery(source_table, class_col, nosample_levels,
""" Return the subquery for fetching all rows as is from the input table
for specific class levels.
if not nosample_levels:
return ''
nosample_level_str = ','.join(["'{0}'".format(level)
for level in nosample_levels if level])
if grp_dict:
grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
for k, v in grp_dict.items())
grp_filter = ''
return """
FROM {source_table}
WHERE ({class_col} in ({nosample_level_str}) OR
{class_col} IS NULL)
# ------------------------------------------------------------------------------
def _get_without_replacement_subquery(schema_madlib, source_table, class_col,
                                      actual_level_counts, desired_level_counts,
                                      grp_dict=None):
    """ Return the subquery for sampling without replacement for specific
    class levels.

    @param schema_madlib: Not used in this function
    @param source_table: Table (or view) to sample from
    @param class_col: Class column expression
    @param actual_level_counts: Not used in this function
    @param desired_level_counts: Dict of class level -> number of rows to keep
    @param grp_dict: Optional dict of grouping column -> value that limits
                     the rows to a single group

    @returns:
        Str. SQL text, or '' when there are no levels to sample.
    """
    if not desired_level_counts:
        return ''
    class_col_tmp = unique_string(desp='class_col')
    row_number_col = unique_string(desp='row_number')
    desired_count_col = unique_string(desp='desired_count')
    source_table_columns = ','.join(get_cols(source_table))
    # A falsy level (the NULL class level) is represented by NULL_IDENTIFIER
    # so that it can be joined against coalesce(class_col::text, ...).
    null_value_string = "'{0}'".format(NULL_IDENTIFIER)
    desired_level_count_pairs = (','.join("({0}, {1})".
                                 format("'{0}'::text".format(k) if k else null_value_string, v)
                                 for k, v in desired_level_counts.items()))
    desired_level_counts_str = "VALUES " + desired_level_count_pairs

    # Subquery q2 is used to figure out the number of rows to select for each
    # class level. That is used with row_number to order the input rows randomly,
    # and trim the results to the desired number of rows for each class level.
    # q1:
    #   The FROM clause contains information that can be used to figure out the
    #   number of rows to generate for a class level. The tuple basically
    #   contains: (class_level, the_desired_number_of_rows)
    if grp_dict:
        grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
                                            for k, v in grp_dict.items())
    else:
        grp_filter = ''

    # The template below is filled from the local variables of this function,
    # so the placeholder names must match the variable names above.
    subquery = """
            SELECT {source_table_columns}
            FROM (
                SELECT {source_table_columns},
                       row_number() OVER (PARTITION BY {class_col} ORDER BY random()) AS {row_number_col},
                       {desired_count_col}
                FROM (
                    SELECT {source_table_columns},
                           {desired_count_col}
                    FROM
                        {source_table} s,
                        ({desired_level_counts_str})
                            q1({class_col_tmp}, {desired_count_col})
                    WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_identifier}')
                          {grp_filter}
                ) q2
            ) q3
            WHERE {row_number_col} <= {desired_count_col}
        """.format(null_identifier=NULL_IDENTIFIER, **locals())
    return subquery
# ------------------------------------------------------------------------------
def _get_with_replacement_subquery(schema_madlib, source_table, class_col,
                                   actual_level_counts, desired_level_counts,
                                   grp_dict=None):
    """ Return the query for sampling with replacement for specific class
    levels. Always used for oversampling since oversampling will always need
    to use replacement. Used for under sampling only if with_replacement
    flag is set to TRUE.

    @param schema_madlib: Not used in this function
    @param source_table: Table (or view) to sample from
    @param class_col: Class column expression
    @param actual_level_counts: Dict of class level -> number of rows present
                                in the input; must contain every key of
                                desired_level_counts
    @param desired_level_counts: Dict of class level -> number of rows to draw
    @param grp_dict: Optional dict of grouping column -> value that limits
                     the rows to a single group

    @returns:
        Str. SQL text, or '' when there are no levels to sample.
    """
    if not desired_level_counts:
        return ''

    source_table_columns = ','.join(get_cols(source_table))
    class_col_tmp = unique_string(desp='class_col_with_rep')
    desired_count_col = unique_string(desp='desired_count_with_rep')
    actual_count_col = unique_string(desp='actual_count')
    q1_row_no = unique_string(desp='q1_row')
    q2_row_no = unique_string(desp='q2_row')

    # A falsy level (the NULL class level) is represented by NULL_IDENTIFIER
    # so that it can be joined against coalesce(class_col::text, ...).
    null_value_string = "'{0}'".format(NULL_IDENTIFIER)
    desired_and_actual_count = (','.join("({0}, {1}, {2})".
                                format("'{0}'::text".format(k) if k else null_value_string,
                                       v, actual_level_counts[k])
                                for k, v in desired_level_counts.items()))
    desired_and_actual_level_count_str = "VALUES " + desired_and_actual_count

    # q1 and q2 are two sub queries we create to generate the required number of
    # rows per class level.
    # q1:
    #   The FROM clause contains information that can be used to figure out the
    #   number of rows to generate for a class level. The tuple basically
    #   contains:
    #   (class_level, the_desired_number_of_rows, actual_rows_for_level_in_input_table)
    #   The SELECT clause uses generate series to duplicate a row {q1_row_no}
    #   of times, which is a value between 1 and actual_rows_for_level_in_input_table
    # q2:
    #   Replicates the source_table with row IDs starting from 1 through
    #   actual_rows_for_level_in_input_table.
    # The WHERE clause is used to join the two subqueries to obtain the result.
    if grp_dict:
        # Unlike the other subquery builders, the filter here is a complete
        # WHERE clause because it is applied inside q2.
        grp_filter = 'WHERE ' + ' AND '.join("{0} = '{1}'".format(k, v)
                                             for k, v in grp_dict.items())
    else:
        grp_filter = ''

    # NOTE(review): the ::int cast on random()*(n-1)+1 presumably rounds to
    # the nearest integer, which would make the first and last row of a level
    # less likely to be drawn than the others - confirm whether that matters.
    # The template below is filled from the local variables of this function,
    # so the placeholder names must match the variable names above.
    subquery = """
            SELECT {source_table_columns}
            FROM
                (
                    SELECT
                        {class_col_tmp},
                        generate_series(1, {desired_count_col}::int) AS _i,
                        ((random()*({actual_count_col}-1)+1)::int) AS {q1_row_no}
                    FROM
                        ({desired_and_actual_level_count_str})
                            q({class_col_tmp}, {desired_count_col}, {actual_count_col})
                ) q1,
                (
                    SELECT
                        *,
                        row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
                    FROM
                        {source_table}
                    {grp_filter}
                ) q2
            WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND
                  q1.{q1_row_no} = q2.{q2_row_no}
        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    return subquery
# ------------------------------------------------------------------------------
# balance_sample: entry point that builds the balanced output table.
#
# Steps visible in the code below:
#   1. Defaults: a missing class_sizes becomes UNIFORM; with_replacement and
#      keep_null are normalised to booleans; class_sizes is stripped.
#   2. Arguments are checked with _validate_strs.
#   3. When keep_null is false, sampling runs against a temporary view of the
#      source table that leaves out rows whose class column is NULL. The view
#      is dropped at the end.
#   4. class_sizes is parsed as key=value pairs. When nothing is parsed it is
#      treated as a strategy name (_validate_and_get_sampling_strategy).
#      Otherwise every listed level must be one of the distinct class levels
#      and every size an integer >= 1.
#   5. Actual per-group level counts and the target counts are computed. For
#      each group the nosample, oversample and undersample subqueries are
#      built and combined with UNION ALL. Undersampling uses the
#      with-replacement query only when with_replacement is set.
#   6. The first group creates the output table and later groups insert into
#      it. Each row gets a row_number() column named NEW_ID_COLUMN.
#
# NOTE(review): this copy of the function has lost its indentation and appears
# to be missing lines. Affected spots include the docstring delimiters,
# several `else:` / `try:` lines, the plpy.execute wrapper around CREATE VIEW,
# the trailing arguments of extract_keyvalue_params, explicit_bool_to_text and
# some .format() calls, the start of the tuple unpacking before
# _get_sampling_strategy_counts, and most of final_query including its
# execution. Restore these from the upstream file before relying on this code.
def balance_sample(schema_madlib, source_table, output_table, class_col,
class_sizes, output_table_size, grouping_cols,
with_replacement, keep_null, **kwargs):
Balance sampling function
@param schema_madlib Schema that MADlib is installed on.
@param source_table Input table name.
@param output_table Output table name.
@param class_col Name of the column containing the class to be
@param class_sizes Parameter to define the size of the different
class values.
@param output_table_size Desired size of the output data set.
@param grouping_cols The columns that define the grouping.
@param with_replacement The sampling method.
@param keep_null Flag to include rows with class level values
NULL. Default is False.
with MinWarning("warning"):
desired_sample_per_class = unique_string(desp='desired_sample_per_class')
desired_counts = unique_string(desp='desired_counts')
# set all default values
if not class_sizes:
class_sizes = UNIFORM
if not with_replacement:
with_replacement = False
keep_null = False if not keep_null else True
if class_sizes:
class_sizes = class_sizes.strip()
_validate_strs(source_table, output_table, class_col,
output_table_size, grouping_cols)
# If keep_null=False, create a view of the input table ignoring NULL
# values for class levels.
if keep_null:
new_source_table = source_table
new_source_table = unique_string(desp='source_table')
CREATE VIEW {new_source_table} AS
SELECT * FROM {source_table}
WHERE {class_col} IS NOT NULL
# class_sizes can be of two forms:
# 1. A string describing sampling strategy (as described in
# _validate_and_get_sampling_strategy).
# In this case, 'sampling_strategy_str' is set to one of
# 2. Class sizes for all (or a subset) of the class levels
# In this case, sampling_strategy_str = None and parsed_class_sizes
# is used for the sampling.
parsed_class_sizes = extract_keyvalue_params(class_sizes,
# GREENPLUM 4.3.X does not support bool-to-text cast which is relied
# upon for class_col in multiple queries. The explicit_bool_to_text
# function wraps the class_col with a MADlib function that provides the
# cast just for those platforms that don't provide the cast. For
# platforms that provide the cast, class_col is unchanged below.
class_col = explicit_bool_to_text(source_table,
distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
distinct_levels = collate_plpy_result(plpy.execute(distinct_sql))['levels']
if not parsed_class_sizes:
sampling_strategy_str = _validate_and_get_sampling_strategy(
class_sizes, output_table_size)
sampling_strategy_str = None
for each_level, each_class_size in parsed_class_sizes.items():
_assert(each_level in distinct_levels,
"Sample: Invalid class value specified ({0})".
each_class_size = int(each_class_size)
_assert(each_class_size >= 1,
"Sample: Class size has to be greater than zero")
parsed_class_sizes[each_level] = each_class_size
except ValueError:
plpy.error("Sample: Invalid value for class_sizes ({0})".
# Get the number of rows to be sampled for each class level, based on
# the input table, class_sizes, and output_table_size params. This also
# includes info about the resulting sampling strategy, i.e., one of
grp_col_str, grp_cols = get_grouping_col_str(
schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
source_table, grouping_cols)
actual_grp_level_counts = _get_level_frequency_distribution(
new_source_table, class_col, grp_col_str)
target_grp_class_sizes = _get_target_level_counts(
sampling_strategy_str, parsed_class_sizes,
actual_grp_level_counts, output_table_size)
is_output_created = False
grp_cols_list = grp_col_str.split(',') if grp_cols else []
# for each group
for grp_vals, actual_level_counts in actual_grp_level_counts.items():
target_class_sizes = target_grp_class_sizes[grp_vals]
nosample_level_counts) = _get_sampling_strategy_counts(target_class_sizes)
# grp_dict represents each grouping column and its value
# for current iteration
grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None
# Get subqueries for each sampling strategy and union together in
# one big query.
# NOSAMPLE are levels that are to be retained in output without sampling
nosample_subquery = _get_nosample_subquery(
new_source_table, class_col, nosample_level_counts.keys(), grp_dict)
# OVERSAMPLE are levels that to be sampled more than existing count
# (always with replacement)
oversample_subquery = _get_with_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, oversample_level_counts, grp_dict)
# UNDERSAMPLE are levels that are to be sampled to less than
# existing count. Undersampling supports both with and without
# replacement.
if with_replacement:
undersample_subquery = _get_with_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, undersample_level_counts, grp_dict)
undersample_subquery = _get_without_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, undersample_level_counts, grp_dict)
# Merge the three subqueries using a UNION ALL clause.
sampling_queries = (undersample_subquery, oversample_subquery, nosample_subquery)
union_all_subquery = ' UNION ALL '.join(
for subquery in sampling_queries if subquery])
# Populate the output table
if is_output_created:
table_header = "INSERT INTO {0}".format(output_table)
table_header = "CREATE TABLE {0} AS ".format(output_table)
is_output_created = True
final_query = """
SELECT row_number() OVER() AS {id_col_name}, *
) union_query
""".format(id_col_name=NEW_ID_COLUMN, **locals())
# end of grouping loop
if not keep_null:
plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table))
# ------------------------------------------------------------------------------
def _validate_strs(source_table, output_table, class_col,
                   output_table_size, grouping_cols):
    """ Validate the user-supplied arguments of balance_sample.

    Each check goes through _assert with a 'Sample: ...' message.

    @param source_table: Must name an existing, non-empty table that has
                         class_col and no column named NEW_ID_COLUMN
    @param output_table: Must be given and must not exist yet
    @param class_col: Must be given and exist in source_table
    @param output_table_size: Must be unset (falsy) or greater than zero
    @param grouping_cols: Not checked in this function
    """
    _assert(source_table and table_exists(source_table),
            "Sample: Source table ({source_table}) does not exist.".format(**locals()))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({source_table}) is empty.".format(**locals()))

    _assert(output_table,
            "Sample: Output table name is missing.".format(**locals()))
    _assert(not table_exists(output_table),
            "Sample: Output table ({output_table}) already exists.".format(**locals()))

    _assert(class_col,
            "Sample: Class column name is missing.".format(**locals()))
    _assert(columns_exist_in_table(source_table, [class_col]),
            ("""Sample: Class column ({class_col}) does not exist in""" +
             """ table ({source_table}).""").format(**locals()))
    # The output table gets an id column with this name, so the source table
    # must not already have one.
    _assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]),
            ("""Sample: Please ensure the source table ({0})""" +
             """ does not contain a column named {1}""").format(source_table, NEW_ID_COLUMN))

    _assert((not output_table_size) or (output_table_size > 0),
            "Sample: Invalid output table size ({output_table_size}).".format(
                **locals()))
# ------------------------------------------------------------------------------
# balance_sample_help: returns the help text for balance_sample.
#
# An empty message selects the short summary. 'usage', 'help' or '?' (any
# case) selects the detailed usage text. Any other message yields the
# "No such option" text. The chosen text is formatted with schema_madlib
# before it is returned.
#
# NOTE(review): this copy has lost its indentation and appears to be missing
# lines (docstring delimiters, the closing quotes of both help strings, the
# final `else:` and possibly heading/separator lines inside the help text).
# Restore these from the upstream file before relying on this code.
# NOTE(review): the usage text describes stratified sampling rather than
# balanced sampling, names the parameter 'class_size' while the function takes
# 'class_sizes', and has no comma after 'with_replacement BOOLEAN' - worth
# correcting once the full text is restored.
def balance_sample_help(schema_madlib, message, **kwargs):
Help function for balance_sample
@param schema_madlib
@param message: string, Help message string
@param kwargs
String. Help/usage information
if not message:
help_string = """
Given a table with varying set of records for each class label,
this function will create an output table with a varying types (by
default: uniform) of sampling distributions of each class label. It is
possible to use with or without replacement sampling methods, specify
different proportions of each class, multiple grouping columns and/or
output table size.
For more details on function usage:
SELECT {schema_madlib}.balance_sample('usage');
elif message.lower() in ['usage', 'help', '?']:
help_string = """
Given a table, stratified sampling returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.
SELECT {schema_madlib}.balance_sample(
source_table TEXT, -- Input table name.
output_table TEXT, -- Output table name.
class_col TEXT, -- Name of column containing the class to be
-- balanced.
class_size TEXT, -- (Default: NULL) Parameter to define the size
-- of the different class values.
output_table_size INTEGER, -- (Default: NULL) Desired size of the output
-- data set.
grouping_cols TEXT, -- (Default: NULL) The columns columns that
-- defines the grouping.
with_replacement BOOLEAN -- (Default: FALSE) The sampling method.
keep_null BOOLEAN -- (Default: FALSE) Consider class levels with
NULL values or not.
If class_size is NULL, the source table is uniformly sampled.
If output_table_size is NULL, the resulting output table size will depend on
the settings for the ‘class_size’ parameter. It is ignored if ‘class_size’
parameter is set to either ‘oversample’ or ‘undersample’.
If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.
If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once.
The output_table would contain the required number of samples, along with a
new column named __madlib_id__, that contain unique numbers for all
sampled rows.
help_string = "No such option. Use {schema_madlib}.balance_sample()"
return help_string.format(schema_madlib=schema_madlib)
import unittest
class UtilitiesTestCase(unittest.TestCase):
    """
    Unit tests for the count/strategy helpers of balance sample.

    Comment "import plpy" and replace plpy.error calls with appropriate
    Python Exceptions to successfully run the test cases
    """

    def setUp(self):
        # Single ungrouped input (group key None) with three class levels.
        self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}}
        self.level1a = 'a'
        self.level1a_cnt1 = 15
        self.level1a_cnt2 = 25
        self.level1a_cnt3 = 20

        # An empty strategy string selects the user-defined class size path.
        self.sampling_strategy_str0 = ''
        self.sampling_strategy_str1 = 'uniform'
        self.sampling_strategy_str2 = 'oversample'
        self.sampling_strategy_str3 = 'undersample'

        self.user_specified_class_size0 = ''
        self.user_specified_class_size1 = {'a': 25, 'b': 25}
        self.user_specified_class_size2 = {'b': 25}
        self.user_specified_class_size3 = {'a': 30}

        self.output_table_size1 = None
        self.output_table_size2 = 60

        # self.input_class_level_counts2 = {'a':100, 'b':100, 'c':100}

    def test__choose_strategy(self):
        self.assertEqual(UNDERSAMPLE, _choose_strategy(35, 25))
        self.assertEqual(OVERSAMPLE, _choose_strategy(15, 25))
        # Equal counts are treated as undersampling.
        self.assertEqual(UNDERSAMPLE, _choose_strategy(25, 25))

    def test__get_target_level_counts(self):
        # Test cases for user defined class level samples, without output table size
        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size1,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))
        self.assertEqual({None: {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size2,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))
        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size3,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))

        # Test cases for user defined class level samples, with output table size
        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size1,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size2))
        self.assertEqual({None: {'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size2,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size2))
        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str0,
                                                  self.user_specified_class_size3,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size2))

        # Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any output table size
        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str1,
                                                  self.user_specified_class_size0,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))
        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str2,
                                                  self.user_specified_class_size0,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))
        self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str3,
                                                  self.user_specified_class_size0,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size1))

        # Test cases for UNIFORM with output table size
        self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
                         _get_target_level_counts(self.sampling_strategy_str1,
                                                  self.user_specified_class_size0,
                                                  self.input_class_level_counts1,
                                                  self.output_table_size2))

    def test__get_sampling_strategy_specific_dict(self):
        # Test cases for getting sampling strategy specific counts
        target_level_counts_1 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}
        target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)}
        target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 'c': (25, NOSAMPLE)}
        self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
                         _get_sampling_strategy_counts(target_level_counts_1))
        self.assertEqual(({'b': 25}, {'a': 25}, {}),
                         _get_sampling_strategy_counts(target_level_counts_2))
        self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
                         _get_sampling_strategy_counts(target_level_counts_3))
if __name__ == '__main__':
    # Run the unit tests above when this module is executed as a script.
    unittest.main()