| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import _assert |
| from utilities.utilities import _check_groups |
| from utilities.utilities import get_table_qualified_col_str |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import unique_string |
| from utilities.utilities import split_quoted_delimited_str |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import get_cols |
| |
| m4_changequote(`<!', `!>') |
| |
def stratified_sample(schema_madlib, source_table, output_table, proportion,
                      grouping_cols, target_cols, with_replacement, **kwargs):
    """
    Stratified sampling function

    Creates output_table holding a random sample of source_table. When
    grouping columns are given, every group (stratum) is sampled on its own.

    Args:
        @param schema_madlib    Schema that hosts the MADlib functions
                                (not used by this function).
        @param source_table     Input table name.
        @param output_table     Output table name.
        @param proportion       The ratio of sample size to the number of
                                records.
        @param grouping_cols    (Default: NULL) The columns to distinguish
                                each strata. A blank string is treated
                                like NULL.
        @param target_cols      (Default: NULL) The columns to include in
                                the output. NULL or '*' selects every
                                column of the source table.
        @param with_replacement (Default: FALSE) The sampling method.

    Returns:
        None. The sample is written to output_table.
    """
    with MinWarning("warning"):
        label = unique_string(desp='label')
        perc = unique_string(desp='perc')

        # A blank grouping string names no columns. Treat it like NULL so
        # that every branch below agrees on whether strata are in use
        # (otherwise "PARTITION BY" / "GROUP BY" would get an empty list).
        if grouping_cols is not None and not grouping_cols.strip():
            grouping_cols = None
        grouped = grouping_cols is not None

        glist = None
        checkg_lp = ""
        window = ""
        grp_by = ""
        grp_from_perc = ""
        grp_comma = ""
        if grouped:
            glist = split_quoted_delimited_str(grouping_cols)
            checkg_lp = " AND " + _check_groups(label, perc, glist)
            window = "PARTITION BY {0}".format(grouping_cols)
            grp_by = "GROUP BY {0}".format(grouping_cols)
            grp_from_perc = get_table_qualified_col_str(perc, glist)
            grp_comma = grouping_cols + " , "

        validate_strs(source_table, output_table, proportion, glist,
                      target_cols)

        # '==' rather than 'is': identity of two equal strings is an
        # interpreter detail and must not decide which columns are used.
        if target_cols is None or target_cols == '*':
            cols = get_cols(source_table)
            if grouped:
                # Grouping columns are emitted separately; skip them here
                # so they do not appear twice in the select lists.
                cols = [item for item in cols if item not in glist]
            target_cols = " , ".join(cols)

        # Select lists are assembled from their non-empty parts only, so a
        # table made up solely of grouping columns (empty target list) does
        # not leave a dangling comma in the generated SQL.
        label_cols = "".join(part + " , "
                             for part in (grouping_cols, target_cols)
                             if part)
        output_cols = " , ".join(part
                                 for part in (grp_from_perc, target_cols)
                                 if part)

        plpy.execute("DROP TABLE IF EXISTS {0},{1}".format(label, perc))
        if not with_replacement:
            if grouped:

                # Create a random label for each record
                sql1 = """ CREATE TEMP TABLE {label} AS (
                    SELECT {label_cols} random() AS __samp_out_label
                    FROM {source_table})""".format(
                    label=label, label_cols=label_cols,
                    source_table=source_table)
                plpy.execute(sql1)

                # Find the cut-off label for the given proportion
                sql2 = """ CREATE TEMP TABLE {perc} AS (
                    SELECT {grouping_cols}, percentile_disc({proportion})
                        WITHIN GROUP (ORDER BY __samp_out_label)
                        AS __samp_out_label
                    FROM {label} GROUP BY {grouping_cols})""".format(
                    perc=perc, grouping_cols=grouping_cols,
                    proportion=proportion, label=label)
                plpy.execute(sql2)

                # Select every record that has a label under the threshold
                sql3 = """ CREATE TABLE {output_table} AS (
                    SELECT {output_cols}
                    FROM {label} INNER JOIN {perc} ON (
                        {label}.__samp_out_label <= {perc}.__samp_out_label
                        {checkg_lp}) )""".format(
                    output_table=output_table, output_cols=output_cols,
                    label=label, perc=perc, checkg_lp=checkg_lp)
                plpy.execute(sql3)
            else:

                # Find the number of records to select
                count = plpy.execute("SELECT count(*) AS count FROM {0}".
                                     format(source_table))[0]['count']
                # NOTE(review): the product may be fractional; it is passed
                # to LIMIT as-is and the database coerces it.
                count = count * proportion

                # Order randomly and select the required number of records
                sql1 = """ CREATE TABLE {output_table} AS (
                    SELECT {target_cols}
                    FROM {source_table}
                    ORDER BY random()
                    LIMIT {count})""".format(
                    output_table=output_table, target_cols=target_cols,
                    source_table=source_table, count=count)
                plpy.execute(sql1)
        else:

            # Set the row number as the label for each record
            # OVER clause ensures that different groups have independent
            # row_numbers
            sql1 = """ CREATE TEMP TABLE {label} AS (
                SELECT {label_cols}
                       row_number() OVER ({window}) AS __samp_out_label
                FROM {source_table})""".format(
                label=label, label_cols=label_cols, window=window,
                source_table=source_table)
            plpy.execute(sql1)

            # Generate a series of random values for each group based on
            # their individual row counts.
            # These random values are independent from each other and may
            # have the same value.
            # The label is drawn as ceil((1 - random()) * count): random()
            # lies in [0, 1), so every row number 1..count is equally
            # likely. Rounding random()*(count-1)+1 to the nearest integer
            # would pick the first and the last row only half as often as
            # the others.
            sql2 = """ CREATE TEMP TABLE {perc} AS (
                SELECT {grp_comma}
                       GENERATE_SERIES(0,(count*{proportion}-1)::int) AS __i,
                       (ceil((1 - random())*count)::int) AS __samp_out_label
                FROM (
                    SELECT {grp_comma} count(*) AS count
                    FROM {source_table} {grp_by}) AS sub)
                """.format(
                perc=perc, grp_comma=grp_comma, proportion=proportion,
                source_table=source_table, grp_by=grp_by)
            plpy.execute(sql2)

            # Join the two tables to get the selected samples.
            # If a random value is generated twice, the join will ensure
            # that the record is selected twice
            sql3 = """ CREATE TABLE {output_table} AS (
                SELECT {output_cols}
                FROM {label} INNER JOIN {perc} ON (
                    {label}.__samp_out_label = {perc}.__samp_out_label
                    {checkg_lp}) )""".format(
                output_table=output_table, output_cols=output_cols,
                label=label, perc=perc, checkg_lp=checkg_lp)
            plpy.execute(sql3)

        # The helper tables are no longer needed once the output exists.
        plpy.execute("DROP TABLE IF EXISTS {0},{1}".format(label, perc))

    return
| |
def validate_strs(source_table, output_table, proportion, glist, target_cols):
    """
    Validate the arguments of stratified_sample.

    Every check goes through _assert, so a failed check surfaces however
    _assert reports a false condition.

    Args:
        @param source_table  Input table name; must name an existing,
                             non-empty table.
        @param output_table  Output table name; must not exist yet.
        @param proportion    Sampling ratio; must satisfy
                             0 < proportion <= 1.
        @param glist         List of grouping column names, or None when
                             no grouping is requested.
        @param target_cols   Comma-separated target columns, or None / '*'
                             for "all columns" (then no column check runs).
    """
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Sample: Invalid output table name {0}!".format(output_table))
    _assert(not table_exists(output_table),
            "Sample: Output table already exists!")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Sample: Invalid Source table name!")
    _assert(table_exists(source_table),
            "Sample: Source table ({0}) is missing!".format(source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({0}) is empty!".format(source_table))

    # Guard None explicitly: comparing None with a number is an error on
    # Python 3 and would hide the intended message.
    _assert(proportion is not None and 0 < proportion <= 1,
            "Sample: Proportion isn't in the range (0,1]!")

    if glist is not None:
        _assert(columns_exist_in_table(source_table, glist),
                ("Sample: Not all columns from {0} are present in source"
                 " table ({1}).").format(glist, source_table))

    # '==' rather than 'is': whether two equal strings are the same object
    # is an interpreter detail.
    if not (target_cols is None or target_cols == '*'):
        tlist = split_quoted_delimited_str(target_cols)
        _assert(columns_exist_in_table(source_table, tlist),
                ("Sample: Not all columns from {0} are present in"
                 " source table ({1})").format(target_cols, source_table))
    return
| |
| |
def stratified_sample_help(schema_madlib, message, **kwargs):
    """
    Help function for stratified_sample

    Args:
        @param schema_madlib: string, schema name substituted into the
                              example calls of the help text
        @param message: string, Help message string. Empty/None returns
                        the summary; 'usage', 'help' or '?' (any case,
                        surrounding whitespace ignored) returns the usage
                        text; anything else returns a short hint.
        @param kwargs: ignored

    Returns:
        String. Help/usage information
    """
    # Normalise once so that ' Usage ' is recognised like 'usage'.
    topic = message.strip().lower() if message else ""

    if not topic:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------

Given a table, stratified sampling returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.

For more details on function usage:
    SELECT {schema_madlib}.stratified_sample('usage');
"""
    elif topic in ('usage', 'help', '?'):
        help_string = """

Given a table, stratified sampling returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.

----------------------------------------------------------------------------
                                USAGE
----------------------------------------------------------------------------

SELECT {schema_madlib}.stratified_sample(
    source_table      TEXT,     -- Input table name.
    output_table      TEXT,     -- Output table name.
    proportion        FLOAT8,   -- The ratio of sample size to the number of
                                -- records.
    grouping_cols     TEXT,     -- (Default: NULL) The columns to distinguish
                                -- each strata.
    target_cols       TEXT,     -- (Default: NULL) The columns to include in
                                -- the output.
    with_replacement  BOOLEAN   -- (Default: FALSE) The sampling method.
);

If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.

If target_cols is NULL or '*', all of the columns will be included in the
output table.

If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.stratified_sample()"

    return help_string.format(schema_madlib=schema_madlib)