blob: 17388bb5ab7c84382692672983bc5237c7294c6c [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import _check_groups
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import add_postfix
from utilities.utilities import unique_string
from utilities.utilities import split_quoted_delimited_str
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_expr_type
from utilities.validate_args import get_cols
m4_changequote(` <!', `!>')
def _get_sql_string(str):
if str:
return "'" + str + "'"
return "NULL"
def train_test_split(schema_madlib, source_table, output_table, train_proportion,
                     test_proportion, grouping_cols, target_cols, with_replacement,
                     separate_output_tables, **kwargs):
    """
    Split a table into a training sample and a test sample.

    A combined sample of size (train_proportion + test_proportion) is drawn
    first with stratified_sample; the test rows are then drawn from that
    combined sample (always without replacement) and the training rows are
    whatever remains of the combined sample (EXCEPT ALL).

    Args:
        @param schema_madlib          Schema that holds stratified_sample.
        @param source_table           Input table name.
        @param output_table           Output table name.
        @param train_proportion       The ratio of training data to the entire
                                      input table.
        @param test_proportion        The ratio of test data to the entire
                                      input table. When None it is taken as
                                      1 - train_proportion.
        @param grouping_cols          (Default: NULL) The columns to distinguish
                                      each strata.
        @param target_cols            (Default: NULL) The columns to include in
                                      the output.
        @param with_replacement       (Default: FALSE) The sampling method used
                                      for the combined sample.
        @param separate_output_tables (Default: FALSE) Create two output tables,
                                      <output_table>_train and <output_table>_test.
                                      Otherwise one output table is created with
                                      an additional column 'split' which takes
                                      the value 0 for test and 1 for training.
    """
    with MinWarning("warning"):
        # No explicit test share: use everything that is not training data.
        if test_proportion is None:
            test_proportion = 1 - train_proportion

        validate_strs(source_table, output_table, train_proportion, test_proportion,
                      split_quoted_delimited_str(grouping_cols), target_cols,
                      with_replacement)

        # From here on these hold SQL fragments, not the raw arguments.
        grouping_cols = _get_sql_string(grouping_cols)
        target_cols = _get_sql_string(target_cols)
        with_replacement = with_replacement or "False"

        strat_query = """
SELECT {schema_madlib}.stratified_sample(
'{strat_source_table}',
'{strat_out_table}',
'{strat_proportion}',
{strat_grouping_cols},
{strat_target_cols},
{strat_with_replacement}
)
"""

        def render_sample_sql(from_table, into_table, proportion, replacement):
            # Fill the stratified_sample template for one sampling step.
            return strat_query.format(
                schema_madlib=schema_madlib,
                strat_source_table=from_table,
                strat_out_table=into_table,
                strat_proportion=proportion,
                strat_grouping_cols=grouping_cols,
                strat_with_replacement=replacement,
                strat_target_cols=target_cols
            )

        # Step 1: draw the combined (train + test) sample into a scratch table.
        strat_out_table = unique_string()
        combined_proportion = train_proportion + test_proportion
        plpy.execute(render_sample_sql(source_table, strat_out_table,
                                       combined_proportion, with_replacement))

        test_table = add_postfix(output_table, "_test")
        train_table = add_postfix(output_table, "_train")
        if not separate_output_tables:
            # Single-table output: both halves are only intermediates.
            test_table = unique_string()
            train_table = unique_string()

        # Step 2: carve the test rows out of the combined sample.
        plpy.execute(render_sample_sql(strat_out_table, test_table,
                                       test_proportion / combined_proportion,
                                       False))

        # Step 3: the training rows are the rest of the combined sample.
        plpy.execute("""
CREATE TABLE {train_table} AS
SELECT * FROM {strat_out_table}
EXCEPT ALL
SELECT * FROM {test_table}
""".format(train_table=train_table,
           strat_out_table=strat_out_table,
           test_table=test_table))

        tables_to_drop = [strat_out_table]
        if not separate_output_tables:
            # Step 4: merge both halves, tagging each row with its split.
            plpy.execute("""
CREATE TABLE {output_table} AS
SELECT *,0 AS split FROM {test_table}
UNION ALL
SELECT *,1 AS split FROM {train_table}
""".format(output_table=output_table,
           test_table=test_table,
           train_table=train_table))
            tables_to_drop.extend([train_table, test_table])

        plpy.execute("""
DROP TABLE IF EXISTS {clean_up_tables}
""".format(clean_up_tables=",".join(tables_to_drop)))
def validate_strs(source_table, output_table, train_proportion, test_proportion, glist, target_cols, with_replacement):
    """
    Validate the arguments of train_test_split before any sampling is run.

    Args:
        @param source_table      Input table name; must exist and be non-empty.
        @param output_table      Output table name; must not exist yet.
        @param train_proportion  Training share, must lie in (0, 1).
        @param test_proportion   Test share, must lie in (0, 1).
        @param glist             List of grouping column names, or None.
        @param target_cols       Comma-delimited target columns, '*' or None.
        @param with_replacement  When false, the two proportions together
                                 must not exceed 1.

    Every failed check is reported through _assert with a "Sample: ..."
    message.
    """
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Sample: Invalid output table name {0}!".format(output_table))
    _assert(not table_exists(output_table),
            "Sample: Output table already exists!")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Sample: Invalid Source table name!")
    _assert(table_exists(source_table),
            "Sample: Source table ({0}) is missing!".format(source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({0}) is empty!".format(source_table))

    for proportion in (train_proportion, test_proportion):
        # Reject None explicitly instead of relying on None/number ordering.
        _assert(proportion is not None and 0 < proportion < 1,
                "Sample: Proportions aren't in the range (0,1)!")

    if not with_replacement:
        # Without replacement the two samples cannot exceed the whole table.
        _assert(train_proportion + test_proportion <= 1,
                "Sample: Proportions add up to greater than 1!")

    if glist is not None:
        _assert(columns_exist_in_table(source_table, glist),
                "Sample: Not all columns from {0} are present in source"
                " table ({1}).".format(glist, source_table))

    # Compare by value: identity ('is') against a string literal is not a
    # reliable equality test.
    if target_cols is not None and target_cols != '*':
        tlist = split_quoted_delimited_str(target_cols)
        _assert(columns_exist_in_table(source_table, tlist),
                "Sample: Not all columns from {0} are present in source"
                " table ({1}).".format(target_cols, source_table))
    return
def train_test_split_help(schema_madlib, message, **kwargs):
    """
    Help function for train_test_split

    Args:
        @param schema_madlib  Schema name substituted into the example calls.
        @param message        string, Help message string. Empty/None selects
                              the summary; 'usage', 'help' or '?' (any case)
                              selects the detailed usage text.
        @param kwargs         Ignored.

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Given a table, train_test_split returns a random sample of the
table for testing and training. It is possible to use with or without
replacement sampling methods, specify a set of target columns, and a
set of grouping columns, in which case, stratified sampling will be
performed.
For more details on function usage:
SELECT {schema_madlib}.train_test_split('usage');
"""
    elif message.lower() in ['usage', 'help', '?']:
        # The signature is closed with ");" before the explanatory notes and
        # every non-final argument ends with a comma, so the listing reads
        # as a well-formed call.
        help_string = """
Given a table, test train split returns a proportion of records for
each group (strata). It is possible to use with or without replacement
sampling methods, specify a set of target columns, and assume the
whole table is a single strata.
----------------------------------------------------------------------------
USAGE
----------------------------------------------------------------------------
SELECT {schema_madlib}.train_test_split(
source_table TEXT, -- Name of the table containing the input data.
output_table TEXT, -- Output table name.
train_proportion FLOAT8, -- The ratio of train sample size to the
-- number of records.
test_proportion FLOAT8, -- The ratio of test sample size to the
-- number of records.
grouping_cols TEXT, -- (Default: NULL) The columns to distinguish
-- each strata.
target_cols TEXT, -- (Default: NULL) The columns to include in
-- the output.
with_replacement BOOLEAN, -- (Default: FALSE) The sampling method.
separate_output_tables
BOOLEAN -- (Default: FALSE) Separate the output table
-- into $output_table$_train and
-- $output_table$_test, otherwise, the split
-- column in output_table will identify 1 for
-- train set and 0 for test set.
);
If grouping_cols is NULL, the whole table is treated as a single group and
sampled accordingly.
If target_cols is NULL or '*', all of the columns will be included in the
output table.
If with_replacement is TRUE, each sample is independent (the same row may
be selected in the sample set more than once). Else (if with_replacement
is FALSE), a row can be selected at most once.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.train_test_split()"
    return help_string.format(schema_madlib=schema_madlib)