blob: 90d12d54412c3a32aec12fe563ac7d2b3af0145a [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import _check_groups
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import split_quoted_delimited_str
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_expr_type
from utilities.validate_args import get_cols
m4_changequote(`<!', `!>')
def stratified_sample(schema_madlib, source_table, output_table, proportion,
                      grouping_cols, target_cols, with_replacement, **kwargs):
    """
    Stratified sampling function

    Creates output_table holding a random sample of the rows of source_table.
    When grouping_cols is given, every group (stratum) is sampled
    independently with the same proportion.

    Args:
        @param schema_madlib     Name of the MADlib schema (not used here).
        @param source_table      Input table name.
        @param output_table      Output table name.
        @param proportion        The ratio of sample size to the number of
                                 records.
        @param grouping_cols     (Default: NULL) The columns to distinguish
                                 each strata. A blank string is treated the
                                 same as NULL.
        @param target_cols       (Default: NULL) The columns to include in
                                 the output. NULL or '*' selects every column
                                 of the source table.
        @param with_replacement  (Default: FALSE) The sampling method.
        @param kwargs            Ignored.

    Returns:
        None. The result is the newly created output_table.
    """
    with MinWarning("warning"):

        # Names of the two intermediate temp tables.
        label = unique_string(desp='label')
        perc = unique_string(desp='perc')

        # A blank grouping string carries no columns. Normalise it to None so
        # that every branch below agrees on whether grouping is in effect
        # (otherwise "PARTITION BY " / "GROUP BY " would be emitted with
        # nothing after them).
        if grouping_cols is not None and not grouping_cols.strip():
            grouping_cols = None

        # SQL fragments that stay empty when there is no grouping.
        checkg_lp = ""
        window = ""
        grp_by = ""
        grp_from_perc = ""
        grp_comma = ""
        glist = None
        if grouping_cols is not None:
            glist = split_quoted_delimited_str(grouping_cols)
            checkg_lp = " AND " + _check_groups(label, perc, glist)
            window = "PARTITION BY {0}".format(grouping_cols)
            grp_by = "GROUP BY {0}".format(grouping_cols)
            grp_from_perc = get_table_qualified_col_str(perc, glist) + " , "
            grp_comma = grouping_cols + " , "

        validate_strs(source_table, output_table, proportion,
                      glist, target_cols)

        # Compare by value: identity comparison against a string literal
        # only works by accident of the interpreter.
        if target_cols is None or target_cols == '*':
            cols = get_cols(source_table)
            if glist is not None:
                # Grouping columns are emitted separately; do not repeat them.
                cols = [item for item in cols if item not in glist]
            target_cols = " , ".join(cols)

        # Both grouped strategies finish with the same join between the
        # labelled rows and the per-group selection; only the comparison
        # operator differs.
        join_template = """ CREATE TABLE {output_table} AS (
            SELECT {grp_from_perc} {target_cols}
            FROM {label} INNER JOIN {perc} ON (
                {label}.__samp_out_label {op} {perc}.__samp_out_label
                {checkg_lp}) )"""

        plpy.execute("DROP TABLE IF EXISTS {0},{1}".format(label, perc))

        if not with_replacement:
            if grouping_cols:
                # Create a random label for each record
                plpy.execute(""" CREATE TEMP TABLE {label} AS (
                    SELECT {target_cols},{grouping_cols},
                           random() AS __samp_out_label
                    FROM {source_table})""".format(
                    label=label,
                    target_cols=target_cols,
                    grouping_cols=grouping_cols,
                    source_table=source_table))

                # Find the cut-off label for the given proportion
                plpy.execute(""" CREATE TEMP TABLE {perc} AS (
                    SELECT {grouping_cols}, percentile_disc({proportion})
                        WITHIN GROUP (ORDER BY __samp_out_label)
                        AS __samp_out_label
                    FROM {label} GROUP BY {grouping_cols})""".format(
                    perc=perc,
                    grouping_cols=grouping_cols,
                    proportion=proportion,
                    label=label))

                # Select every record that has a label under the threshold
                plpy.execute(join_template.format(
                    output_table=output_table,
                    grp_from_perc=grp_from_perc,
                    target_cols=target_cols,
                    label=label,
                    perc=perc,
                    op="<=",
                    checkg_lp=checkg_lp))
            else:
                # Find the number of records to select
                total = plpy.execute(
                    "SELECT count(*) AS count FROM {0}".format(source_table)
                )[0]['count']
                sample_size = total * proportion

                # Order randomly and select the required number of records
                plpy.execute(""" CREATE TABLE {output_table} AS (
                    SELECT {target_cols}
                    FROM {source_table}
                    ORDER BY random()
                    LIMIT {sample_size})""".format(
                    output_table=output_table,
                    target_cols=target_cols,
                    source_table=source_table,
                    sample_size=sample_size))
        else:
            # Set the row number as the label for each record
            # OVER clause ensures that different groups have independent
            # row_numbers
            plpy.execute(""" CREATE TEMP TABLE {label} AS (
                SELECT {grp_comma} {target_cols},
                       row_number() OVER ({window}) AS __samp_out_label
                FROM {source_table})""".format(
                label=label,
                grp_comma=grp_comma,
                target_cols=target_cols,
                window=window,
                source_table=source_table))

            # Generate a series of random values for each group based on their
            # individual row counts.
            # These random values are independent from each other and may have
            # the same value.
            # floor(random()*count)+1 is uniform over 1..count. Casting
            # random()*(count-1)+1 straight to int would round instead, which
            # gives the first and last row only half the probability of the
            # others.
            plpy.execute(""" CREATE TEMP TABLE {perc} AS (
                SELECT {grp_comma}
                       GENERATE_SERIES(0,(count*{proportion}-1)::int) AS __i,
                       ((floor(random()*count)+1)::int) AS __samp_out_label
                FROM (
                    SELECT {grp_comma} count(*) AS count
                    FROM {source_table} {grp_by}) AS sub)
                """.format(
                perc=perc,
                grp_comma=grp_comma,
                proportion=proportion,
                source_table=source_table,
                grp_by=grp_by))

            # Join the two tables to get the selected samples.
            # If a random value is generated twice, the join will ensure that
            # the record is selected twice
            plpy.execute(join_template.format(
                output_table=output_table,
                grp_from_perc=grp_from_perc,
                target_cols=target_cols,
                label=label,
                perc=perc,
                op="=",
                checkg_lp=checkg_lp))

        plpy.execute("DROP TABLE IF EXISTS {0},{1}".format(label, perc))
    return
def validate_strs(source_table, output_table, proportion, glist, target_cols):
    """
    Validate the arguments of stratified_sample

    Each check is delegated to _assert with an explanatory message; how a
    failed check is reported is decided by _assert.

    Args:
        @param source_table  Input table name; must exist and be non-empty.
        @param output_table  Output table name; must not exist yet.
        @param proportion    Sampling ratio; must satisfy 0 < proportion <= 1.
        @param glist         List of grouping column names, or None when no
                             grouping is requested.
        @param target_cols   Comma-separated target columns; None or '*'
                             means all columns and skips the column check.
    """
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Sample: Invalid output table name {output_table}!".format(
                output_table=output_table))
    _assert(not table_exists(output_table),
            "Sample: Output table already exists!")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Sample: Invalid Source table name!")
    _assert(table_exists(source_table),
            "Sample: Source table ({source_table}) is missing!".format(
                source_table=source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({source_table}) is empty!".format(
                source_table=source_table))

    # A NULL proportion must fail the check explicitly rather than depend on
    # how the interpreter orders None against numbers. The upper bound 1 is
    # accepted, so the message states a half-open range.
    _assert(proportion is not None and 0 < proportion <= 1,
            "Sample: Proportion isn't in the range (0,1]!")

    if glist is not None:
        _assert(columns_exist_in_table(source_table, glist),
                ("""Sample: Not all columns from {glist} are present in source""" +
                 """ table ({source_table}).""").format(
                    glist=glist, source_table=source_table))

    # Compare the wildcard by value, not by object identity.
    if not (target_cols is None or target_cols == '*'):
        tlist = split_quoted_delimited_str(target_cols)
        _assert(columns_exist_in_table(source_table, tlist),
                ("""Sample: Not all columns from {target_cols} are present in""" +
                 """ source table ({source_table})""").format(
                    target_cols=target_cols, source_table=source_table))
    return
def stratified_sample_help(schema_madlib, message, **kwargs):
    """
    Help function for stratified_sample

    Args:
        @param schema_madlib: string, Schema name substituted into the
                              returned text.
        @param message: string, Help message string. An empty or NULL value
                        selects the summary; 'usage', 'help' or '?'
                        (case-insensitive) selects the detailed usage; any
                        other value selects a short "no such option" notice.
        @param kwargs: Ignored.

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------

Given a table, stratified sampling returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.

For more details on function usage:
    SELECT {schema_madlib}.stratified_sample('usage');
            """
    elif message.lower() in ['usage', 'help', '?']:
        # The argument list is closed before the explanatory notes so the
        # printed signature reads as a complete call.
        help_string = """

Given a table, stratified sampling returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.

----------------------------------------------------------------------------
                            USAGE
----------------------------------------------------------------------------

 SELECT {schema_madlib}.stratified_sample(
    source_table      TEXT,     -- Input table name.
    output_table      TEXT,     -- Output table name.
    proportion        FLOAT8,   -- The ratio of sample size to the number of
                                -- records.
    grouping_cols     TEXT,     -- (Default: NULL) The columns to distinguish
                                -- each strata.
    target_cols       TEXT,     -- (Default: NULL) The columns to include in
                                -- the output.
    with_replacement  BOOLEAN   -- (Default: FALSE) The sampling method.
 );

If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.

If target_cols is NULL or '*', all of the columns will be included in the
output table.

If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.stratified_sample()"

    return help_string.format(schema_madlib=schema_madlib)