# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file EXCEPT in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# m4_changequote(`<!', `!>')
import math
from collections import defaultdict
if __name__ != "__main__":
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import collate_plpy_result
from utilities.utilities import get_grouping_col_str
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import explicit_bool_to_text
from utilities.validate_args import get_cols
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
else:
# Used only for Unit Testing
# FIXME: repeating a function from utilities that is needed by the unit test.
# This should be removed once a unittest framework is used for testing.
import random
import time
def unique_string(desp='', **kwargs):
"""
Generate random temporary names for temp tables and other objects.
It has a SQL interface so both SQL and Python functions can call it.
"""
r1 = random.randint(1, 100000000)
r2 = int(time.time())
r3 = int(time.time()) % random.randint(1, 100000000)
u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__"
return u_string
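# Illustrative example (digits are random, shown only for the shape of the
# result): unique_string(desp='sample') could return something like
# '__madlib_temp_sample12345_1500000000_6789__'.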
# ------------------------------------------------------------------------------
UNIFORM = 'uniform'
UNDERSAMPLE = 'undersample'
OVERSAMPLE = 'oversample'
NOSAMPLE = 'nosample'
NEW_ID_COLUMN = '__madlib_id__'
NULL_IDENTIFIER = '__madlib_null_id__'
def _get_level_frequency_distribution(source_table, class_col,
grp_by_cols=None):
""" Count the number of rows for each class, partitioned by the grp_by_cols
Returns a dict keyed by group value (None when no grouping is used); each
value is an inner dict mapping a class level (cast to text using ::text) to
the number of rows for that level. None is also a valid class-level key,
capturing NULL values in the database.
"""
if grp_by_cols and grp_by_cols.lower() != 'null':
is_grouping = True
grp_by_cols_comma = grp_by_cols + ', '
array_grp_by_cols_comma = "ARRAY[{0}]".format(grp_by_cols) + " AS group_values, "
else:
is_grouping = False
grp_by_cols_comma = array_grp_by_cols_comma = ""
# In below query, the inner query groups the data using grp_by_cols + classes
# and obtains the count for each combination. The outer query then groups
# again by the grp_by_cols to collect the classes and counts in an array.
query_result = plpy.execute("""
SELECT
-- For each group get the classes and their rows counts
{grp_identifier} as group_values,
array_agg(classes) as classes,
array_agg(class_count) as class_count
FROM(
-- for each group and class combination present in source table
-- get the count of rows for that combination
SELECT
{array_grp_by_cols_comma}
({class_col})::TEXT AS classes,
count(*) AS class_count
FROM {source_table}
GROUP BY {grp_by_cols_comma} ({class_col})
) q
{meta_grp_by}
""".format(grp_identifier="group_values" if is_grouping else "NULL",
meta_grp_by="GROUP BY group_values" if is_grouping else "",
**locals()))
if (len(query_result) > 1) != is_grouping:
# if is_grouping then query_result should have more than 1 row
# if not is_grouping then query_result should have only 1 row
raise RuntimeError("Balance sample: Error during frequency level distribution")
actual_grp_level_counts = {}
for each_row in query_result:
# group_values is a list for each row; convert it to a tuple to use as
# key in a dictionary
grp = tuple(each_row['group_values']) if is_grouping else None
grp_levels, grp_counts = each_row['classes'], each_row['class_count']
actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts))
return actual_grp_level_counts
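# Illustrative shape of the returned value (hypothetical data): with grouping
# on a single column the keys are tuples of group values,
#     {('g1',): {'a': 20, 'b': 30}, ('g2',): {'a': 5, 'b': 45}}
# and without grouping there is a single None key,
#     {None: {'a': 20, 'b': 30, 'c': 25}}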
# ------------------------------------------------------------------------------
def _validate_and_get_sampling_strategy(sampling_strategy_str,
output_table_size,
default=UNIFORM):
""" Returns the sampling strategy based on the class_sizes input param.
@param sampling_strategy_str The sampling strategy specified by the
user (class_sizes param)
@returns:
Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
"""
if not sampling_strategy_str:
sampling_strategy_str = default
else:
if len(sampling_strategy_str) < 3:
# Require at least 3 characters since UNIFORM and UNDERSAMPLE share
# the common prefix 'un'
plpy.error("Sample: Invalid class_sizes parameter")
supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
try:
# allow user to specify a prefix substring of
# supported strategies.
sampling_strategy_str = next(x for x in supported_strategies
if x.startswith(sampling_strategy_str.lower()))
except StopIteration:
# next() raises StopIteration if no matching element is found
plpy.error("Sample: Invalid class_sizes parameter: "
"{0}. Supported class_size parameters are ({1})".
format(sampling_strategy_str,
','.join(sorted(supported_strategies))))
_assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or
(sampling_strategy_str.find('=') > 0),
"Sample: Invalid class_sizes parameter: "
"{0}. Supported class_size parameters are ({1})".
format(sampling_strategy_str,
','.join(sorted(supported_strategies))))
_assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size),
"Sample: Cannot set output_table_size with oversampling.")
_assert(not(sampling_strategy_str.lower() == 'undersample' and output_table_size),
"Sample: Cannot set output_table_size with undersampling.")
return sampling_strategy_str
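# Illustrative behaviour (not exercised by the unit tests below): any prefix
# of at least three characters resolves to the full strategy name, e.g.
# 'over' -> 'oversample', 'und' -> 'undersample', 'uni' -> 'uniform'; an
# empty value falls back to the default ('uniform').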
# ------------------------------------------------------------------------------
def _choose_strategy(actual_count, desired_count):
""" Choose sampling strategy by comparing actual and desired sample counts
@param actual_count: Actual number of samples for some level
@param desired_count: Desired number of samples for the level
@returns:
Str. Sampling strategy string (either UNDERSAMPLE or OVERSAMPLE)
"""
# OVERSAMPLE when the actual count is less than the desired count
# UNDERSAMPLE when the actual count is more than the desired count
# If the actual count for a class level is the same as desired count, then
# we could potentially return the input rows as is. This, however,
# precludes the case of bootstrapping (i.e. returning same number of rows
# but after sampling with replacement). Hence, we treat the actual=desired
# as UNDERSAMPLE. It's specifically set to UNDERSAMPLE since it provides
# both 'with' and 'without' replacement (OVERSAMPLE is always with
# replacement and NOSAMPLE is always without replacement)
if actual_count < desired_count:
return OVERSAMPLE
else:
return UNDERSAMPLE
# -------------------------------------------------------------------------
def _get_desired_target_level_counts(desired_level_counts,
actual_grp_level_counts,
output_table_size):
""" Return the target counts for each group and each class level in the group
This function is specifically used when the user has provided the
targets for all (or a subset) of the levels. The strategy of either
under or oversampling for each class level is chosen based on the
desired number of counts for a level, and the actual number of counts
already present in the input table. This calculation is performed for
each group (or once if no grouping is used).
@returns: dict: The key of the dictionary is the group value and the
value is another dictionary. The inner dictionary gives
the target count for each level present in the group.
"""
target_grp_level_counts = defaultdict(dict)
n_grp_values = len(actual_grp_level_counts)
for each_grp, actual_level_counts in actual_grp_level_counts.items():
for each_level, desired_count in desired_level_counts.items():
# actual_grp_level_counts are group specific, whereas
# desired_level_counts are evenly split among all groups
try:
per_grp_desired_count = math.ceil(float(desired_count) / n_grp_values)
sample_strategy = _choose_strategy(actual_level_counts[each_level],
per_grp_desired_count)
target_grp_level_counts[each_grp][each_level] = (
per_grp_desired_count, sample_strategy)
except KeyError:
plpy.error("Balance sample: Desired class level ({0}) not present in the "
"data for each group.".format(each_level))
# desired levels could contain just a subset of all levels. For the remaining
# levels in actual_level_counts, compute the desired counts
remaining_levels = (set(actual_level_counts.keys()) -
set(desired_level_counts.keys()))
# if 'output_table_size' = NULL, remaining level counts remain as is
# if 'output_table_size' = <Integer>, divide remaining count
# uniformly among remaining levels
if output_table_size:
# output_table_size is for the whole table and should be split evenly
# between the groups
remaining_rows = math.ceil(float(output_table_size -
sum(desired_level_counts.values())) /
n_grp_values)
if remaining_rows > 0:
# Uniformly distribute the remaining class levels
rows_per_level = math.ceil(float(remaining_rows) /
len(remaining_levels))
for each_level in remaining_levels:
sample_strategy = _choose_strategy(
actual_level_counts[each_level], rows_per_level)
target_grp_level_counts[each_grp][each_level] = (
rows_per_level, sample_strategy)
else:
# When output_table_size is unspecified, rows from the input table
# are sampled as is for the remaining class levels. This is called the
# NOSAMPLE strategy.
for each_level in remaining_levels:
target_grp_level_counts[each_grp][each_level] = (
actual_level_counts[each_level], NOSAMPLE)
return target_grp_level_counts
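# Worked example (mirrors the unit tests below): with actual counts
# {None: {'a': 20, 'b': 30, 'c': 25}}, desired counts {'a': 25, 'b': 25} and
# output_table_size = 60, levels 'a' and 'b' get 25 rows each (OVERSAMPLE and
# UNDERSAMPLE respectively); the remaining ceil((60 - 50) / 1) = 10 rows are
# assigned to 'c' as (10, UNDERSAMPLE).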
# -------------------------------------------------------------------------
def _get_supported_target_level_counts(sampling_strategy_str,
actual_grp_level_counts,
output_table_size):
""" Returns the target level counts for all levels when the class_size param
is one of [uniform, undersample, oversample]. The strategy of
(under)oversampling for a specific class level is chosen based on the
computed number of counts for a level, and the actual number of counts
already present in the input table.
"""
target_grp_level_counts = defaultdict(dict)
n_grp_values = len(actual_grp_level_counts)
for each_grp, actual_level_counts in actual_grp_level_counts.items():
if sampling_strategy_str == UNIFORM:
# UNIFORM: Ensure all level counts are same
if output_table_size:
# Ignore actual counts for computing target sizes
# if output_table_size is specified
total_ = float(output_table_size) / n_grp_values
else:
total_ = sum(actual_level_counts.values())
target_size_per_level = math.ceil(float(total_) /
len(actual_level_counts))
else:
# UNDERSAMPLE: Ensure all level counts are same as the minimum count
# OVERSAMPLE: Ensure all level counts are same as the maximum count
if sampling_strategy_str == UNDERSAMPLE:
target_size_per_level = min(actual_level_counts.values())
elif sampling_strategy_str == OVERSAMPLE:
target_size_per_level = max(actual_level_counts.values())
else:
raise RuntimeError("Balance sample: Invalid "
"sampling_strategy_str encountered")
for each_level, actual_count in actual_level_counts.items():
sample_strategy = _choose_strategy(actual_count, target_size_per_level)
target_grp_level_counts[each_grp][each_level] = (
target_size_per_level, sample_strategy)
return target_grp_level_counts
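# Worked example (mirrors the unit tests below): with actual counts
# {None: {'a': 20, 'b': 30, 'c': 25}} and class_sizes = 'uniform' without an
# output_table_size, the target per level is ceil(75 / 3) = 25, giving
# {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}.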
# -------------------------------------------------------------------------
def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
actual_grp_level_counts, output_table_size):
"""
@param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None].
This is None only when the class sizes are
user-defined, i.e., a comma-separated list of
(class level, desired number of rows) pairs.
@param desired_level_counts: Dict that is defined only when the previous arg
sampling_strategy_str is None. This dict would
then contain the class levels and the
corresponding number of rows specified by
the user.
@param actual_grp_level_counts: Dictionary that provides, for each group,
the row count of each class present in
the group
@param output_table_size: Size of the desired output table (NULL or Integer)
@returns:
Dict. Number of samples to be drawn, and the sampling strategy to be
used for each class level.
"""
if not sampling_strategy_str:
# This case implies the user has provided a desired count for one or more
# levels. Counts for rest of the levels depend on 'output_table_size'.
target_level_counts = _get_desired_target_level_counts(
desired_level_counts, actual_grp_level_counts, output_table_size)
else:
# This case implies the user has chosen one of
# [uniform, undersample, oversample] for the class_size parameter.
target_level_counts = _get_supported_target_level_counts(
sampling_strategy_str, actual_grp_level_counts, output_table_size)
return target_level_counts
# -------------------------------------------------------------------------
def _get_sampling_strategy_counts(target_class_sizes):
""" Return three dicts, one each for undersampling, oversampling, and
nosampling. Each dict maps a class level to the number of samples to be
drawn for that level.
"""
undersample_level_counts = {}
oversample_level_counts = {}
nosample_level_counts = {}
for level, (count, strategy) in target_class_sizes.items():
if strategy == UNDERSAMPLE:
undersample_level_counts[level] = count
elif strategy == OVERSAMPLE:
oversample_level_counts[level] = count
else:
nosample_level_counts[level] = count
return (undersample_level_counts, oversample_level_counts, nosample_level_counts)
# ------------------------------------------------------------------------------
def _get_nosample_subquery(source_table, class_col, nosample_levels,
grp_dict=None):
""" Return the subquery for fetching all rows as is from the input table
for specific class levels.
"""
if not nosample_levels:
return ''
nosample_level_str = ','.join(["'{0}'".format(level)
for level in nosample_levels if level])
if grp_dict:
grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
for k, v in grp_dict.items())
else:
grp_filter = ''
return """
SELECT *
FROM {source_table}
WHERE ({class_col} in ({nosample_level_str}) OR
{class_col} IS NULL)
{grp_filter}
""".format(**locals())
# ------------------------------------------------------------------------------
def _get_without_replacement_subquery(schema_madlib, source_table, class_col,
actual_level_counts, desired_level_counts,
grp_dict=None):
""" Return the subquery for sampling without replacement for specific
class levels.
"""
if not desired_level_counts:
return ''
class_col_tmp = unique_string(desp='class_col')
row_number_col = unique_string(desp='row_number')
desired_count_col = unique_string(desp='desired_count')
source_table_columns = ','.join(get_cols(source_table))
null_value_string = "'{0}'".format(NULL_IDENTIFIER)
desired_level_count_pairs = (','.join("({0}, {1})".
format("'{0}'::text".format(k) if k else null_value_string, v)
for k, v in desired_level_counts.items()))
desired_level_counts_str = "VALUES " + desired_level_count_pairs
# Subquery q2 is used to figure out the number of rows to select for each
# class level. That is used with row_number to order the input rows randomly,
# and trim the results to the desired number of rows for each class level.
# q1:
# The FROM clause contains information that can be used to figure out the
# number of rows to generate for a class level. The tuple basically
# contains: (class_level, the_desired_number_of_rows)
if grp_dict:
grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
for k, v in grp_dict.items())
else:
grp_filter = ''
subquery = """
SELECT {source_table_columns}
FROM (
SELECT {source_table_columns},
row_number() OVER (PARTITION BY {class_col} ORDER BY random()) AS {row_number_col},
{desired_count_col}
FROM (
SELECT {source_table_columns},
{desired_count_col}
FROM
{source_table} s,
({desired_level_counts_str})
q1({class_col_tmp}, {desired_count_col})
WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_identifier}')
{grp_filter}
) q2
) q3
WHERE {row_number_col} <= {desired_count_col}
""".format(null_identifier=NULL_IDENTIFIER, **locals())
return subquery
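# Sketch of the generated SQL (hypothetical counts): for desired_level_counts
# {'a': 10}, q1 is "VALUES ('a'::text, 10)"; rows of class 'a' are assigned a
# random row_number within the class partition and only those with
# row_number <= 10 are kept, which samples 10 rows without replacement.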
# ------------------------------------------------------------------------------
def _get_with_replacement_subquery(schema_madlib, source_table, class_col,
actual_level_counts, desired_level_counts,
grp_dict=None):
""" Return the query for sampling with replacement for specific class
levels. Always used for oversampling since oversampling will always need
to use replacement. Used for undersampling only if the with_replacement
flag is set to TRUE.
"""
if not desired_level_counts:
return ''
source_table_columns = ','.join(get_cols(source_table))
class_col_tmp = unique_string(desp='class_col_with_rep')
desired_count_col = unique_string(desp='desired_count_with_rep')
actual_count_col = unique_string(desp='actual_count')
q1_row_no = unique_string(desp='q1_row')
q2_row_no = unique_string(desp='q2_row')
null_value_string = "'{0}'".format(NULL_IDENTIFIER)
desired_and_actual_count = (','.join("({0}, {1}, {2})".
format("'{0}'::text".format(k) if k else null_value_string,
v, actual_level_counts[k])
for k, v in desired_level_counts.items()))
desired_and_actual_level_count_str = "VALUES " + desired_and_actual_count
# q1 and q2 are two sub queries we create to generate the required number of
# rows per class level.
# q1:
# The FROM clause contains information that can be used to figure out the
# number of rows to generate for a class level. The tuple basically
# contains:
# (class_level, the_desired_number_of_rows, actual_rows_for_level_in_input_table)
# The SELECT clause uses generate_series to emit the_desired_number_of_rows
# rows per class level; each emitted row also picks a random source row id
# ({q1_row_no}), a value between 1 and actual_rows_for_level_in_input_table
# q2:
# Replicates the source_table with row IDs starting from 1 through
# actual_rows_for_level_in_input_table.
#
# The WHERE clause is used to join the two subqueries to obtain the result.
if grp_dict:
grp_filter = 'WHERE ' + ' AND '.join("{0} = '{1}'".format(k, v)
for k, v in grp_dict.items())
else:
grp_filter = ''
subquery = """
SELECT {source_table_columns}
FROM
(
SELECT
{class_col_tmp},
generate_series(1, {desired_count_col}::int) AS _i,
((random()*({actual_count_col}-1)+1)::int) AS {q1_row_no}
FROM
({desired_and_actual_level_count_str})
q({class_col_tmp}, {desired_count_col}, {actual_count_col})
) q1,
(
SELECT
*,
row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
FROM
{source_table}
{grp_filter}
) q2
WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND
q1.{q1_row_no} = q2.{q2_row_no}
""".format(null_level_val=NULL_IDENTIFIER, **locals())
return subquery
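# Sketch of the generated SQL (hypothetical counts): for desired_level_counts
# {'a': 30} and actual_level_counts {'a': 20}, q1 becomes
# "VALUES ('a'::text, 30, 20)" expanded by generate_series into 30 rows, each
# carrying a random source row id between 1 and 20; joining that id against
# the row-numbered source table picks 30 'a' rows with replacement.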
# ------------------------------------------------------------------------------
def balance_sample(schema_madlib, source_table, output_table, class_col,
class_sizes, output_table_size, grouping_cols,
with_replacement, keep_null, **kwargs):
"""
Balance sampling function
Args:
@param schema_madlib Schema that MADlib is installed on.
@param source_table Input table name.
@param output_table Output table name.
@param class_col Name of the column containing the class to be
balanced.
@param class_sizes Parameter to define the size of the different
class values.
@param output_table_size Desired size of the output data set.
@param grouping_cols The columns that define the grouping.
@param with_replacement Whether to sample with replacement.
@param keep_null Flag to include rows with class level values
NULL. Default is False.
"""
with MinWarning("warning"):
desired_sample_per_class = unique_string(desp='desired_sample_per_class')
desired_counts = unique_string(desp='desired_counts')
# set all default values
if not class_sizes:
class_sizes = UNIFORM
if not with_replacement:
with_replacement = False
keep_null = False if not keep_null else True
if class_sizes:
class_sizes = class_sizes.strip()
_validate_strs(source_table, output_table, class_col,
output_table_size, grouping_cols)
# If keep_null=False, create a view of the input table ignoring NULL
# values for class levels.
if keep_null:
new_source_table = source_table
else:
new_source_table = unique_string(desp='source_table')
plpy.execute("""
CREATE VIEW {new_source_table} AS
SELECT * FROM {source_table}
WHERE {class_col} IS NOT NULL
""".format(**locals()))
# class_sizes can be of two forms:
# 1. A string describing sampling strategy (as described in
# _validate_and_get_sampling_strategy).
# In this case, 'sampling_strategy_str' is set to one of
# [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
# 2. Class sizes for all (or a subset) of the class levels
# In this case, sampling_strategy_str = None and parsed_class_sizes
# is used for the sampling.
parsed_class_sizes = extract_keyvalue_params(class_sizes,
allow_duplicates=False,
lower_case_names=False)
# GREENPLUM 4.3.X does not support bool-to-text cast which is relied
# upon for class_col in multiple queries. The explicit_bool_to_text
# function wraps the class_col with a MADlib function that provides the
# cast just for those platforms that don't provide the cast. For
# platforms that provide the cast, class_col is unchanged below.
class_col = explicit_bool_to_text(source_table,
[class_col],
schema_madlib)[0]
distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
format(class_col,
source_table))
distinct_levels = collate_plpy_result(plpy.execute(distinct_sql))['levels']
if not parsed_class_sizes:
sampling_strategy_str = _validate_and_get_sampling_strategy(
class_sizes, output_table_size)
else:
sampling_strategy_str = None
try:
for each_level, each_class_size in parsed_class_sizes.items():
_assert(each_level in distinct_levels,
"Sample: Invalid class value specified ({0})".
format(each_level))
each_class_size = int(each_class_size)
_assert(each_class_size >= 1,
"Sample: Class size has to be greater than zero")
parsed_class_sizes[each_level] = each_class_size
except ValueError:
plpy.error("Sample: Invalid value for class_sizes ({0})".
format(class_sizes))
# Get the number of rows to be sampled for each class level, based on
# the input table, class_sizes, and output_table_size params. This also
# includes info about the resulting sampling strategy, i.e., one of
# UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
grp_col_str, grp_cols = get_grouping_col_str(
schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
source_table, grouping_cols)
actual_grp_level_counts = _get_level_frequency_distribution(
new_source_table, class_col, grp_col_str)
target_grp_class_sizes = _get_target_level_counts(
sampling_strategy_str, parsed_class_sizes,
actual_grp_level_counts, output_table_size)
is_output_created = False
grp_cols_list = grp_col_str.split(',') if grp_cols else []
# for each group
for grp_vals, actual_level_counts in actual_grp_level_counts.items():
target_class_sizes = target_grp_class_sizes[grp_vals]
(undersample_level_counts,
oversample_level_counts,
nosample_level_counts) = _get_sampling_strategy_counts(target_class_sizes)
# grp_dict represents each grouping column and its value
# for current iteration
grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None
# Get subqueries for each sampling strategy and union together in
# one big query.
# NOSAMPLE are levels that are to be retained in output without sampling
nosample_subquery = _get_nosample_subquery(
new_source_table, class_col, nosample_level_counts.keys(), grp_dict)
# OVERSAMPLE are levels that are to be sampled to more than the existing count
# (always with replacement)
oversample_subquery = _get_with_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, oversample_level_counts, grp_dict)
# UNDERSAMPLE are levels that are to be sampled to less than
# existing count. Undersampling supports both with and without
# replacement.
if with_replacement:
undersample_subquery = _get_with_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, undersample_level_counts, grp_dict)
else:
undersample_subquery = _get_without_replacement_subquery(
schema_madlib, new_source_table, class_col,
actual_level_counts, undersample_level_counts, grp_dict)
# Merge the three subqueries using a UNION ALL clause.
sampling_queries = (undersample_subquery, oversample_subquery, nosample_subquery)
union_all_subquery = ' UNION ALL '.join(
['({0})'.format(subquery)
for subquery in sampling_queries if subquery])
# Populate the output table
if is_output_created:
table_header = "INSERT INTO {0}".format(output_table)
else:
table_header = "CREATE TABLE {0} AS ".format(output_table)
is_output_created = True
final_query = """
{table_header}
SELECT row_number() OVER() AS {id_col_name}, *
FROM (
{union_all_subquery}
) union_query
""".format(id_col_name=NEW_ID_COLUMN, **locals())
plpy.execute(final_query)
# end of grouping loop
if not keep_null:
plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table))
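# Illustrative SQL invocation (table and column names are hypothetical;
# argument order as documented in balance_sample_help below):
#   SELECT madlib.balance_sample('flights', 'flights_out', 'carrier',
#                                'uniform', NULL, NULL, FALSE, FALSE);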
# ------------------------------------------------------------------------------
def _validate_strs(source_table, output_table, class_col,
output_table_size, grouping_cols):
_assert(source_table and table_exists(source_table),
"Sample: Source table ({source_table}) does not exist.".format(**locals()))
_assert(not table_is_empty(source_table),
"Sample: Source table ({source_table}) is empty.".format(**locals()))
_assert(output_table,
"Sample: Output table name is missing.".format(**locals()))
_assert(not table_exists(output_table),
"Sample: Output table ({output_table}) already exists.".format(**locals()))
_assert(class_col,
"Sample: Class column name is missing.".format(**locals()))
_assert(columns_exist_in_table(source_table, [class_col]),
("""Sample: Class column ({class_col}) does not exist in""" +
""" table ({source_table}).""").format(**locals()))
_assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]),
("""Sample: Please ensure the source table ({0})""" +
""" does not contain a column named {1}""").format(source_table, NEW_ID_COLUMN))
_assert((not output_table_size) or (output_table_size > 0),
"Sample: Invalid output table size ({output_table_size}).".format(
**locals()))
# ------------------------------------------------------------------------------
def balance_sample_help(schema_madlib, message, **kwargs):
"""
Help function for balance_sample
Args:
@param schema_madlib
@param message: string, Help message string
@param kwargs
Returns:
String. Help/usage information
"""
if not message:
help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Given a table with a varying number of records for each class label,
this function creates an output table with a specified sampling
distribution (by default: uniform) of the class labels. It is possible
to sample with or without replacement, specify different proportions for
each class, use multiple grouping columns, and/or set the output table
size.
For more details on function usage:
SELECT {schema_madlib}.balance_sample('usage');
"""
elif message.lower() in ['usage', 'help', '?']:
help_string = """
Given a table, balanced sampling returns the desired number of records
for each class level. It is possible to sample with or without
replacement, specify the desired size for each class level, and treat
the whole table as a single group or sample within groups.
----------------------------------------------------------------------------
USAGE
----------------------------------------------------------------------------
SELECT {schema_madlib}.balance_sample(
source_table TEXT, -- Input table name.
output_table TEXT, -- Output table name.
class_col TEXT, -- Name of column containing the class to be
-- balanced.
class_size TEXT, -- (Default: NULL) Parameter to define the size
-- of the different class values.
output_table_size INTEGER, -- (Default: NULL) Desired size of the output
-- data set.
grouping_cols TEXT, -- (Default: NULL) The columns that
-- define the grouping.
with_replacement BOOLEAN, -- (Default: FALSE) Whether to sample with
-- replacement.
keep_null BOOLEAN -- (Default: FALSE) Whether to keep class levels
-- with NULL values.
);
If class_size is NULL, the source table is uniformly sampled.
If output_table_size is NULL, the resulting output table size depends on the
setting of the 'class_size' parameter. output_table_size is ignored if
'class_size' is set to either 'oversample' or 'undersample'.
If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.
If with_replacement is TRUE, each sample is independent (the same row may be
selected in the sample set more than once). Otherwise (with_replacement is
FALSE), a row can be selected at most once.
The output_table contains the required number of samples, along with a new
column named __madlib_id__ that contains a unique number for each sampled
row.
"""
else:
help_string = "No such option. Use {schema_madlib}.balance_sample()"
return help_string.format(schema_madlib=schema_madlib)
import unittest
class UtilitiesTestCase(unittest.TestCase):
"""
Comment out "import plpy" and replace plpy.error calls with appropriate
Python exceptions to run the test cases successfully.
"""
def setUp(self):
self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}}
self.level1a = 'a'
self.level1a_cnt1 = 15
self.level1a_cnt2 = 25
self.level1a_cnt3 = 20
self.sampling_strategy_str0 = ''
self.sampling_strategy_str1 = 'uniform'
self.sampling_strategy_str2 = 'oversample'
self.sampling_strategy_str3 = 'undersample'
self.user_specified_class_size0 = ''
self.user_specified_class_size1 = {'a': 25, 'b': 25}
self.user_specified_class_size2 = {'b': 25}
self.user_specified_class_size3 = {'a': 30}
self.output_table_size1 = None
self.output_table_size2 = 60
# self.input_class_level_counts2 = {'a':100, 'b':100, 'c':100}
def test__choose_strategy(self):
self.assertEqual(UNDERSAMPLE, _choose_strategy(35, 25))
self.assertEqual(OVERSAMPLE, _choose_strategy(15, 25))
self.assertEqual(UNDERSAMPLE, _choose_strategy(25, 25))
def test__get_target_level_counts(self):
# Test cases for user defined class level samples, without output table size
self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size1,
self.input_class_level_counts1,
self.output_table_size1))
self.assertEqual({None: {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size2,
self.input_class_level_counts1,
self.output_table_size1))
self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size3,
self.input_class_level_counts1,
self.output_table_size1))
# Test cases for user defined class level samples, with output table size
self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size1,
self.input_class_level_counts1,
self.output_table_size2))
self.assertEqual({None: {'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size2,
self.input_class_level_counts1,
self.output_table_size2))
self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str0,
self.user_specified_class_size3,
self.input_class_level_counts1,
self.output_table_size2))
# Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any output table size
self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str1,
self.user_specified_class_size0,
self.input_class_level_counts1,
self.output_table_size1))
self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str2,
self.user_specified_class_size0,
self.input_class_level_counts1,
self.output_table_size1))
self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str3,
self.user_specified_class_size0,
self.input_class_level_counts1,
self.output_table_size1))
# Test cases for UNIFORM with output table size
self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
_get_target_level_counts(self.sampling_strategy_str1,
self.user_specified_class_size0,
self.input_class_level_counts1,
self.output_table_size2))
def test__get_sampling_strategy_specific_dict(self):
# Test cases for getting sampling strategy specific counts
target_level_counts_1 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}
target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)}
target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 'c': (25, NOSAMPLE)}
self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
_get_sampling_strategy_counts(target_level_counts_1))
self.assertEqual(({'b': 25}, {'a': 25}, {}),
_get_sampling_strategy_counts(target_level_counts_2))
self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
_get_sampling_strategy_counts(target_level_counts_3))
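    def test__unique_string_format(self):
        # Added sanity check (illustrative): the fallback unique_string
        # defined above for unit testing should produce temp names with the
        # expected prefix and trailing underscores.
        name = unique_string(desp='sample')
        self.assertTrue(name.startswith('__madlib_temp_sample'))
        self.assertTrue(name.endswith('__'))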
if __name__ == '__main__':
unittest.main()