| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import _assert |
| from utilities.utilities import _check_groups |
| from utilities.utilities import get_table_qualified_col_str |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import add_postfix |
| from utilities.utilities import unique_string |
| from utilities.utilities import split_quoted_delimited_str |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import get_cols |
| |
| m4_changequote(` <!', `!>') |
| |
| |
def _get_sql_string(str):
    """
    Render an optional Python string as a SQL literal.

    Args:
        @param str  Text to embed in a SQL statement, or None / ''.
                    (The name shadows the builtin; it is kept unchanged so
                    that existing callers are not affected.)

    Returns:
        The text wrapped in single quotes, with every embedded single quote
        doubled so that it cannot terminate the literal early, or the bare
        keyword NULL when the input is None or empty.
    """
    if str:
        # Doubling is the SQL-standard way of escaping a quote inside a
        # string constant; without it a value such as "o'clock" yields a
        # malformed (or injectable) statement.
        return "'" + str.replace("'", "''") + "'"
    return "NULL"
| |
| |
def train_test_split(schema_madlib, source_table, output_table, train_proportion,
                     test_proportion, grouping_cols, target_cols, with_replacement,
                     separate_output_tables, **kwargs):
    """
    Split a table into a training sample and a test sample.

    The function first draws one combined sample of size
    (train_proportion + test_proportion) through stratified_sample, then
    draws the test rows out of that combined sample (always without
    replacement) and finally takes whatever is left as the training rows.

    Args:
        @param schema_madlib            Schema in which stratified_sample is
                                        looked up.
        @param source_table             Input table name.
        @param output_table             Output table name.
        @param train_proportion         The ratio of training data to the
                                        entire input table.
        @param test_proportion          The ratio of test data to the entire
                                        input table. When None it is set to
                                        1 - train_proportion.
        @param grouping_cols            (Default: NULL) The columns to
                                        distinguish each strata.
        @param target_cols              (Default: NULL) The columns to
                                        include in the output.
        @param with_replacement         (Default: FALSE) The sampling method
                                        used for the combined sample.
        @param separate_output_tables   (Default: FALSE) Create two output
                                        tables, <output_table>_train and
                                        <output_table>_test. Otherwise one
                                        output table is created with an
                                        additional column 'split' which takes
                                        the value 0 for test and 1 for
                                        training.

    Returns:
        None. The results are written to the output table(s).
    """
    with MinWarning("warning"):
        if test_proportion is None:
            test_proportion = 1 - train_proportion

        group_list = split_quoted_delimited_str(grouping_cols)
        validate_strs(source_table, output_table, train_proportion,
                      test_proportion, group_list, target_cols,
                      with_replacement)

        # Arguments that are identical for both stratified_sample calls.
        shared_args = dict(
            schema_madlib=schema_madlib,
            strat_grouping_cols=_get_sql_string(grouping_cols),
            strat_target_cols=_get_sql_string(target_cols))
        replacement_flag = with_replacement or "False"
        total_proportion = train_proportion + test_proportion

        sample_sql = """
            SELECT {schema_madlib}.stratified_sample(
                '{strat_source_table}',
                '{strat_out_table}',
                '{strat_proportion}',
                {strat_grouping_cols},
                {strat_target_cols},
                {strat_with_replacement}
            )
        """

        # Step 1: one combined sample that holds both train and test rows.
        combined_table = unique_string()
        plpy.execute(sample_sql.format(
            strat_source_table=source_table,
            strat_out_table=combined_table,
            strat_proportion=total_proportion,
            strat_with_replacement=replacement_flag,
            **shared_args))

        # Final table names when two outputs are requested; otherwise the
        # two halves live in temporary tables until they are merged below.
        test_table = add_postfix(output_table, "_test")
        train_table = add_postfix(output_table, "_train")
        if not separate_output_tables:
            test_table = unique_string()
            train_table = unique_string()

        # Step 2: carve the test rows out of the combined sample. This draw
        # is never done with replacement.
        plpy.execute(sample_sql.format(
            strat_source_table=combined_table,
            strat_out_table=test_table,
            strat_proportion=test_proportion / total_proportion,
            strat_with_replacement=False,
            **shared_args))

        # Step 3: the training rows are the combined sample minus the test
        # rows (EXCEPT ALL keeps duplicate rows that were not drawn).
        plpy.execute("""
            CREATE TABLE {train_table} AS
                SELECT * FROM {strat_out_table}
                EXCEPT ALL
                SELECT * FROM {test_table}
        """.format(train_table=train_table,
                   strat_out_table=combined_table,
                   test_table=test_table))

        tables_to_drop = [combined_table]
        if not separate_output_tables:
            # Single output: stack both halves and tag them with 'split'.
            plpy.execute("""
                CREATE TABLE {output_table} AS
                    SELECT *,0 AS split FROM {test_table}
                    UNION ALL
                    SELECT *,1 AS split FROM {train_table}
            """.format(output_table=output_table,
                       test_table=test_table,
                       train_table=train_table))
            tables_to_drop.extend([train_table, test_table])

        plpy.execute("""
            DROP TABLE IF EXISTS {clean_up_tables}
        """.format(clean_up_tables=",".join(tables_to_drop)))
| |
| |
def validate_strs(source_table, output_table, train_proportion, test_proportion,
                  glist, target_cols, with_replacement):
    """
    Validate the arguments of train_test_split.

    Every check is reported through _assert with a message that starts with
    "Sample:".

    Args:
        @param source_table      Input table name; must name an existing,
                                 non-empty table.
        @param output_table      Output table name; must not exist yet.
        @param train_proportion  Training proportion; must lie in (0, 1).
        @param test_proportion   Test proportion; must lie in (0, 1).
        @param glist             List of grouping column names, or None.
        @param target_cols       Delimited string of output columns, None,
                                 or '*' for all columns.
        @param with_replacement  When falsy, the two proportions together
                                 must not exceed 1.

    Returns:
        None.
    """
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Sample: Invalid output table name {output_table}!".format(
                output_table=output_table))
    _assert(not table_exists(output_table),
            "Sample: Output table already exists!")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Sample: Invalid Source table name!")
    _assert(table_exists(source_table),
            "Sample: Source table ({source_table}) is missing!".format(
                source_table=source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({source_table}) is empty!".format(
                source_table=source_table))

    for proportion in (train_proportion, test_proportion):
        # The explicit None test keeps the intended error message instead of
        # a TypeError on interpreters that refuse to order None and numbers.
        _assert(proportion is not None and 0 < proportion < 1,
                "Sample: Proportions aren't in the range (0,1)!")
    if not with_replacement:
        _assert(train_proportion + test_proportion <= 1,
                "Sample: Proportions add up to greater than 1!")

    if glist is not None:
        _assert(columns_exist_in_table(source_table, glist),
                ("""Sample: Not all columns from {glist} are present in source""" +
                 """ table ({source_table}).""").format(
                     glist=glist, source_table=source_table))

    # Compare by value: 'is' tests object identity and only matches a string
    # literal by accident of interning.
    if not (target_cols is None or target_cols == '*'):
        tlist = split_quoted_delimited_str(target_cols)
        _assert(columns_exist_in_table(source_table, tlist),
                ("""Sample: Not all columns from {target_cols} are present in""" +
                 """ source table ({source_table})""").format(
                     target_cols=target_cols, source_table=source_table))
    return
| |
| |
def train_test_split_help(schema_madlib, message, **kwargs):
    """
    Help function for train_test_split

    Args:
        @param schema_madlib  Schema name substituted into the example calls.
        @param message        string, Help message string. None or '' selects
                              the summary; 'usage', 'help' or '?' (any case,
                              surrounding whitespace ignored) selects the
                              full usage text.
        @param kwargs         Ignored.

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------

Given a table, train_test_split returns a random sample of the
table for testing and training. It is possible to use with or without
replacement sampling methods, specify a set of target columns, and a
set of grouping columns, in which case, stratified sampling will be
performed.

For more details on function usage:
    SELECT {schema_madlib}.train_test_split('usage');
            """
    elif message.strip().lower() in ('usage', 'help', '?'):
        # The parameter list below is written as a valid call signature:
        # every argument but the last ends with a comma and the closing
        # parenthesis comes before the explanatory notes.
        help_string = """

Given a table, test train split returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.

----------------------------------------------------------------------------
                                    USAGE
----------------------------------------------------------------------------

 SELECT {schema_madlib}.train_test_split(
    source_table      TEXT,     -- Name of the table containing the input data.
    output_table      TEXT,     -- Output table name.
    train_proportion  FLOAT8,   -- The ratio of train sample size to the
                                -- number of records.
    test_proportion   FLOAT8,   -- The ratio of test sample size to the
                                -- number of records.
    grouping_cols     TEXT,     -- (Default: NULL) The columns to distinguish
                                -- each strata.
    target_cols       TEXT,     -- (Default: NULL) The columns to include in
                                -- the output.
    with_replacement  BOOLEAN,  -- (Default: FALSE) The sampling method.
    separate_output_tables
                      BOOLEAN   -- (Default: FALSE) Separate the output table
                                -- into $output_table$_train and
                                -- $output_table$_test, otherwise, the split
                                -- column in output_table will identify 1 for
                                -- train set and 0 for test set.
 );

If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.

If target_cols is NULL or '*', all of the columns will be included in the
output table.

If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.train_test_split()"

    return help_string.format(schema_madlib=schema_madlib)