blob: 9408739e10fe1658f33785b7cf6523c1d9cdfe5d [file] [log] [blame]
# coding=utf-8
"""
@file robust_linear.py_in
@namespace robust
@brief Robust variance: Common functions
"""
import plpy
from utilities.utilities import _assert
from utilities.utilities import unique_string
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import add_postfix
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
from utilities.validate_args import columns_exist_in_table
# use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0
from utilities.utilities import __mad_version
# Version-aware helpers obtained from the utilities module. Per the note
# above, they exist to handle arrays passed as strings on older servers
# (GPDB < 4.1 and PG < 9.0).
# NOTE(review): neither string_to_array nor array_to_string is referenced in
# the visible part of this file -- confirm whether other code relies on them.
version_wrapper = __mad_version()
# presumably converts a string-encoded array into a Python list -- TODO confirm
string_to_array = version_wrapper.select_vecfunc()
# presumably the inverse conversion used for return values -- TODO confirm
array_to_string = version_wrapper.select_vec_return()
def _robust_linregr_validate(schema_madlib, source_table, output_table,
                             dependent_varname, independent_varname,
                             grouping_cols, verbose_mode, **kwargs):
    """
    @brief Validate the arguments passed to robust_variance_linregr.

    Every check is delegated to _assert with a "Robust Variance error: ..."
    message; the checks run in the order listed below.

    @param schema_madlib string, name of the MADlib schema
    @param source_table string, name of the input table; must be a usable
           name, must exist and must not be empty
    @param output_table string, name of the output table; neither it nor its
           summary table may already exist
    @param dependent_varname string, dependent variable column/expression
    @param independent_varname string, independent variables column/expression
    @param grouping_cols string, comma separated grouping columns; a falsy
           value (None or '') means "no grouping" and is not checked further
    @param verbose_mode bool, must be an actual boolean (None is rejected)

    Returns:
        None
    """
    def _is_usable_name(value):
        # A usable name is non-empty and is not the literal text 'null'.
        return bool(value) and value.strip().lower() not in ('null', '')

    # --- source table ------------------------------------------------------
    _assert(_is_usable_name(source_table),
            "Robust Variance error: Invalid data table name!")
    _assert(table_exists(source_table),
            "Robust Variance error: Data table does not exist!")
    _assert(not table_is_empty(source_table),
            "Robust Variance error: Data table is empty!")

    # --- output tables -----------------------------------------------------
    _assert(_is_usable_name(output_table),
            "Robust Variance error: Invalid output table name!")
    _assert(not table_exists(output_table, only_first_schema=True),
            "Robust Variance error: Output table already exists!")
    # Build the summary name with add_postfix, exactly as the training
    # function does when it creates the table, so that the name checked here
    # is the name that will actually be created.
    _assert(not table_exists(add_postfix(output_table, "_summary"),
                             only_first_schema=True),
            "Robust Variance error: Output summary table already exists!")

    # --- model columns -----------------------------------------------------
    _assert(_is_usable_name(dependent_varname),
            "Robust Variance error: Invalid dependent column name!")
    _assert(_is_usable_name(independent_varname),
            "Robust Variance error: Invalid independent column name!")

    # --- optional grouping -------------------------------------------------
    if grouping_cols:
        _assert(grouping_cols.strip().lower() not in ('null', ''),
                "Robust Variance error: Invalid grouping columns name!")
        _assert(columns_exist_in_table(
                    source_table,
                    _string_to_array_with_quotes(grouping_cols),
                    schema_madlib),
                "Robust Variance error: Grouping column does not exist!")

    # --- flags -------------------------------------------------------------
    _assert(verbose_mode is not None and isinstance(verbose_mode, bool),
            "Robust Variance error: The verbose_mode should be of boolean type!")
# -------------------------------------------------------------------------
def robust_linregr_help(schema_madlib, message, **kwargs):
    """
    @brief Return the help text for robust_variance_linregr.

    @param schema_madlib string, schema name substituted into the help text
    @param message string, requested topic:
           - empty or None: short summary
           - 'usage', 'help' or '?': full usage and output description
           - anything else: a "No such option" hint

    Returns:
        string, the help message with the schema name filled in
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Calculate Huber-White robust statistics for linear regression
For more details on function usage:
SELECT {schema_madlib}.robust_variance_linregr('usage')
"""
    elif message in ['usage', 'help', '?']:
        # The column list below must match the columns selected when the
        # output table is created (coef, std_err, t_stats, p_values).
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.robust_variance_linregr(
'source_table', -- Name of data table
'output_table', -- Name of result table
'dependent_varname', -- Name of column for dependent variables
'independent_varname', -- Name of column for independent variables
(can be any SQL expression that evaluates to an array)
'group_cols', -- [OPTIONAL] Comma separated string with columns to group by.
-- Default is NULL.
'verbose_mode' -- [OPTIONAL] Should warning messages be printed on screen.
-- Default is FALSE.
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table (''output_table'' above) has the following columns:
'coef' DOUBLE PRECISION[], -- Coefficients of regression
'std_err' DOUBLE PRECISION[], -- Huber-White standard errors
't_stats' DOUBLE PRECISION[], -- T-stats of the standard errors
'p_values' DOUBLE PRECISION[] -- p-values of the standard errors
The output summary table is the same as linregr_train(), see also:
SELECT {schema_madlib}.linregr_train('usage');
"""
    else:
        help_string = "No such option. Use {schema_madlib}.robust_variance_linregr()"
    return help_string.format(schema_madlib=schema_madlib)
# -------------------------------------------------------------------------
def robust_variance_linregr(
        schema_madlib, source_table, out_table, dependent_varname,
        independent_varname, grouping_cols=None, verbose_mode=None, **kwargs):
    """
    @brief Compute Huber-White robust variance statistics for linear regression.

    The function first runs linregr_train into a temporary table, copies the
    relevant part of its summary table into <out_table>_summary, and then
    runs the robust_linregr aggregate over the source data using the fitted
    coefficients. The results are written to out_table.

    @param schema_madlib string, name of the MADlib schema
    @param source_table string, name of the input table
    @param out_table string, name of the output table to be created
    @param dependent_varname: string, Column containing the dependent variable
    @param independent_varname string, Column containing the array of independent variables
    @param grouping_cols string, Set of columns to group by. None or an empty
           string means no grouping.
    @param verbose_mode bool, when true warnings are shown, otherwise only
           errors. Validation rejects anything that is not a boolean.

    To include an intercept in the model, set one coordinate in the
    <tt>independentVariables</tt> array to 1.

    Returns:
        None
    """
    # Reset the message level to avoid random messages
    old_msg_level = plpy.execute("""
        SELECT setting
        FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose_mode:
        plpy.execute('SET client_min_messages TO warning')
    else:
        plpy.execute('SET client_min_messages TO error')

    _robust_linregr_validate(schema_madlib, source_table, out_table,
                             dependent_varname, independent_varname,
                             grouping_cols, verbose_mode)

    # The validator accepts any falsy grouping_cols (None or '') as "no
    # grouping". Use the same test here: checking only for None would turn
    # an empty string into "GROUP BY " / "USING ()", which is invalid SQL.
    if grouping_cols:
        group_str = 'GROUP BY %s' % grouping_cols
        group_str_sel = grouping_cols + ','
        join_str = 'JOIN'
        using_str = 'USING (%s)' % grouping_cols
        group_col_str = "'" + grouping_cols + "'"
    else:
        group_str = ''
        group_str_sel = ''
        # Without groups the model table is combined with the source table
        # through a plain comma join.
        join_str = ','
        using_str = ''
        group_col_str = 'NULL'

    # NOTE(review): table and column names are interpolated directly into
    # the SQL text below; callers are expected to pass trusted identifiers.
    lr_out_table = "pg_temp." + unique_string()
    lr_out_table_summary = add_postfix(lr_out_table, "_summary")
    out_table_summary = add_postfix(out_table, "_summary")
    rb_model = unique_string()

    # Run linear regression
    plpy.execute("""
        SELECT {schema_madlib}.linregr_train(
            '{source_table}', '{lr_out_table}',
            '{dependent_varname}', '{independent_varname}', {group_col_str})
        """.format(schema_madlib=schema_madlib, source_table=source_table,
                   lr_out_table=lr_out_table,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname,
                   group_col_str=group_col_str))

    # Create output summary table
    plpy.execute("""
        CREATE TABLE {out_table_summary} AS
        SELECT
            '{source_table}' AS source_table,
            '{out_table}' AS output_table,
            '{dependent_varname}' AS dependent_varname,
            '{independent_varname}' AS independent_varname,
            num_rows_processed, num_missing_rows_skipped
        FROM
            {lr_out_table_summary}
        """.format(source_table=source_table, out_table=out_table,
                   out_table_summary=out_table_summary,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname,
                   lr_out_table_summary=lr_out_table_summary))

    # Run robust linear regression
    plpy.execute("""
        CREATE TABLE {out_table} AS
        SELECT
            {group_str_sel}
            ({rb_model}).coef, ({rb_model}).std_err,
            ({rb_model}).t_stats, ({rb_model}).p_values
        FROM
        (
            SELECT
                {group_str_sel}
                {schema_madlib}.robust_linregr(
                    {dependent_varname},
                    {independent_varname},
                    {lr_out_table}.coef) AS {rb_model}
            FROM
                {source_table} {join_str} {lr_out_table} {using_str}
            {group_str}
        ) t1
        """.format(schema_madlib=schema_madlib,
                   source_table=source_table, out_table=out_table,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname,
                   group_str_sel=group_str_sel, group_str=group_str,
                   join_str=join_str, using_str=using_str,
                   lr_out_table=lr_out_table, rb_model=rb_model))

    # Clean up the intermediate regression tables. The summary table is
    # dropped under the same name it was read from above.
    plpy.execute('DROP TABLE IF EXISTS ' + lr_out_table)
    plpy.execute('DROP TABLE IF EXISTS ' + lr_out_table_summary)

    # Restore the caller's message level. This is deliberately not placed in
    # a finally block: issuing further queries while an error is propagating
    # could mask the original error.
    plpy.execute("SET client_min_messages TO %s" % old_msg_level)