| # coding=utf-8 |
| |
| |
| """ |
| @file robust_linear.py_in |
| |
| @namespace robust |
| |
| @brief Robust variance: Common functions |
| """ |
| import plpy |
| from utilities.utilities import _assert |
| from utilities.utilities import unique_string |
| from utilities.utilities import _string_to_array |
| from utilities.utilities import _string_to_array_with_quotes |
| from utilities.utilities import add_postfix |
| |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import columns_exist_in_table |
| |
# Version-dependent array helpers. Per the original note, mad_vec is used to
# process arrays that are passed as strings on GPDB < 4.1 and PG < 9.0.
# NOTE(review): neither string_to_array nor array_to_string is referenced in
# the part of this module visible here -- confirm other users before removing.
from utilities.utilities import __mad_version
version_wrapper = __mad_version()
string_to_array = version_wrapper.select_vecfunc()
array_to_string = version_wrapper.select_vec_return()
| |
| |
def _robust_linregr_validate(schema_madlib, source_table, output_table,
                             dependent_varname, independent_varname,
                             grouping_cols, verbose_mode, **kwargs):
    """
    @brief Validate the arguments of robust_variance_linregr.

    Each condition is checked with _assert and a message prefixed with
    "Robust Variance error:" (presumably _assert raises when the condition
    is false -- its implementation lives in utilities.utilities).

    @param schema_madlib string, Name of the MADlib schema
    @param source_table string, Name of the input table; must exist and
           must not be empty
    @param output_table string, Name of the output table; neither it nor
           its '_summary' companion may already exist
    @param dependent_varname string, Dependent variable column/expression
    @param independent_varname string, Independent variables expression
    @param grouping_cols string, Comma-separated grouping columns; only
           checked when truthy (None and '' both mean "no grouping")
    @param verbose_mode bool, Must be a real boolean (None is rejected)
    """
    _assert(source_table and
            source_table.strip().lower() not in ('null', ''),
            "Robust Variance error: Invalid data table name!")
    _assert(table_exists(source_table),
            "Robust Variance error: Data table does not exist!")
    _assert(not table_is_empty(source_table),
            "Robust Variance error: Data table is empty!")

    _assert(output_table and
            output_table.strip().lower() not in ('null', ''),
            "Robust Variance error: Invalid output table name!")
    _assert(not table_exists(output_table, only_first_schema=True),
            "Robust Variance error: Output table already exists!")
    # Build the summary table name with add_postfix, the same helper
    # robust_variance_linregr uses when it creates that table, so the name
    # checked here is the name that will actually be created (plain string
    # concatenation differs for a double-quoted table name).
    _assert(not table_exists(add_postfix(output_table, '_summary'),
                             only_first_schema=True),
            "Robust Variance error: Output summary table already exists!")

    _assert(dependent_varname and
            dependent_varname.strip().lower() not in ('null', ''),
            "Robust Variance error: Invalid dependent column name!")
    _assert(independent_varname and
            independent_varname.strip().lower() not in ('null', ''),
            "Robust Variance error: Invalid independent column name!")

    if grouping_cols:
        _assert(grouping_cols.strip().lower() not in ('null', ''),
                "Robust Variance error: Invalid grouping columns name!")
        _assert(columns_exist_in_table(
                    source_table,
                    _string_to_array_with_quotes(grouping_cols),
                    schema_madlib),
                "Robust Variance error: Grouping column does not exist!")

    _assert(verbose_mode is not None and isinstance(verbose_mode, bool),
            "Robust Variance error: The verbose_mode should be of boolean type!")
# -------------------------------------------------------------------------
| |
| |
def robust_linregr_help(schema_madlib, message, **kwargs):
    """
    @brief Return the help text for robust_variance_linregr.

    @param schema_madlib string, Name of the MADlib schema; substituted
           into the returned text
    @param message string, Help topic. Empty/None returns the summary;
           'usage', 'help' or '?' (case-insensitive, surrounding
           whitespace ignored) returns the usage text; anything else
           returns a "No such option" hint

    @return string, The formatted help message
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------
Functionality: Calculate Huber-White robust statistics for linear regression

For more details on function usage:
    SELECT {schema_madlib}.robust_variance_linregr('usage')
"""
    elif message.strip().lower() in ('usage', 'help', '?'):
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.robust_variance_linregr(
    'source_table',         -- Name of data table
    'output_table',         -- Name of result table
    'dependent_varname',    -- Name of column for dependent variables
    'independent_varname',  -- Name of column for independent variables
                            -- (can be any SQL expression that evaluates to an array)
    'group_cols',           -- [OPTIONAL] Comma separated string with columns to group by.
                            -- Default is NULL.
    'verbose_mode'          -- [OPTIONAL] Should warning messages be printed on screen.
                            -- Default is FALSE.
);
-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) has the following columns:
    'coef'      DOUBLE PRECISION[],  -- Coefficients of regression
    'std_err'   DOUBLE PRECISION[],  -- Huber-White standard errors
    't_stats'   DOUBLE PRECISION[],  -- T-stats of the standard errors
    'p_values'  DOUBLE PRECISION[]   -- p-values of the standard errors

The output summary table is the same as linregr_train(), see also:
    SELECT {schema_madlib}.linregr_train('usage');
"""
    else:
        help_string = "No such option. Use {schema_madlib}.robust_variance_linregr()"

    return help_string.format(schema_madlib=schema_madlib)
# -------------------------------------------------------------------------
| |
| |
def robust_variance_linregr(
        schema_madlib, source_table, out_table, dependent_varname,
        independent_varname, grouping_cols=None, verbose_mode=None, **kwargs):
    """
    @brief Compute Huber-White robust variance statistics for linear
           regression.

    Runs linregr_train into a temporary table, then feeds the fitted
    coefficients to the robust_linregr aggregate. Creates two tables:
    out_table (optional grouping columns, coef, std_err, t_stats, p_values)
    and its '_summary' companion (source_table, output_table,
    dependent_varname, independent_varname, num_rows_processed,
    num_missing_rows_skipped).

    @param schema_madlib string, Name of the MADlib schema
    @param source_table string, Name of the input table
    @param out_table string, Name of the output table to be created
    @param dependent_varname string, Column containing the dependent variable
    @param independent_varname string, Column (or SQL expression) containing
           the array of independent variables
    @param grouping_cols string, Comma-separated columns to group by.
           None or an empty string means no grouping.
    @param verbose_mode bool, When true, client_min_messages is set to
           'warning' for the duration of the call, otherwise to 'error'.
           Validation rejects anything that is not a boolean (including
           None).

    To include an intercept in the model, set one coordinate in the
    <tt>independentVariables</tt> array to 1.

    Returns:
        None
    """
    def _literal(value):
        # Render value as a SQL string literal, doubling embedded single
        # quotes so that e.g. an independent_varname expression containing
        # a quoted constant does not terminate the literal early.
        return "'" + value.replace("'", "''") + "'"

    # Remember the caller's message level so it can be restored at the end.
    # No try/finally is used: a failure aborts the transaction, and SET is
    # rolled back together with it.
    old_msg_level = plpy.execute("""
        SELECT setting
        FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose_mode:
        plpy.execute('SET client_min_messages TO warning')
    else:
        plpy.execute('SET client_min_messages TO error')

    _robust_linregr_validate(schema_madlib, source_table, out_table,
                             dependent_varname, independent_varname,
                             grouping_cols, verbose_mode)

    # The validator treats every falsy grouping_cols (None or '') as
    # "no grouping"; do the same here so that '' cannot produce a dangling
    # "GROUP BY" / "USING ()" in the generated SQL.
    if grouping_cols:
        group_str = 'GROUP BY %s' % grouping_cols
        group_str_sel = grouping_cols + ','
        join_str = 'JOIN'
        using_str = 'USING (%s)' % grouping_cols
        group_col_str = _literal(grouping_cols)
    else:
        group_str = ''
        group_str_sel = ''
        # Ungrouped model table holds the single fitted model: cross join.
        join_str = ','
        using_str = ''
        group_col_str = 'NULL'

    lr_out_table = "pg_temp." + unique_string()
    lr_out_table_summary = add_postfix(lr_out_table, "_summary")
    out_table_summary = add_postfix(out_table, "_summary")
    rb_model = unique_string()

    # Run linear regression
    plpy.execute("""
        SELECT {schema_madlib}.linregr_train(
            {source_table_lit}, {lr_out_table_lit},
            {dependent_varname_lit}, {independent_varname_lit},
            {group_col_str})
        """.format(schema_madlib=schema_madlib,
                   source_table_lit=_literal(source_table),
                   lr_out_table_lit=_literal(lr_out_table),
                   dependent_varname_lit=_literal(dependent_varname),
                   independent_varname_lit=_literal(independent_varname),
                   group_col_str=group_col_str))

    # Create output summary table
    plpy.execute("""
        CREATE TABLE {out_table_summary} AS
        SELECT
            {source_table_lit} AS source_table,
            {out_table_lit} AS output_table,
            {dependent_varname_lit} AS dependent_varname,
            {independent_varname_lit} AS independent_varname,
            num_rows_processed, num_missing_rows_skipped
        FROM
            {lr_out_table_summary}
        """.format(out_table_summary=out_table_summary,
                   source_table_lit=_literal(source_table),
                   out_table_lit=_literal(out_table),
                   dependent_varname_lit=_literal(dependent_varname),
                   independent_varname_lit=_literal(independent_varname),
                   lr_out_table_summary=lr_out_table_summary))

    # Run robust linear regression
    plpy.execute("""
        CREATE TABLE {out_table} AS
        SELECT
            {group_str_sel}
            ({rb_model}).coef, ({rb_model}).std_err,
            ({rb_model}).t_stats, ({rb_model}).p_values
        FROM
        (
            SELECT
                {group_str_sel}
                {schema_madlib}.robust_linregr(
                    {dependent_varname},
                    {independent_varname},
                    {lr_out_table}.coef) AS {rb_model}
            FROM
                {source_table} {join_str} {lr_out_table} {using_str}
            {group_str}
        ) t1
        """.format(schema_madlib=schema_madlib,
                   source_table=source_table, out_table=out_table,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname,
                   group_str_sel=group_str_sel, group_str=group_str,
                   join_str=join_str, using_str=using_str,
                   lr_out_table=lr_out_table, rb_model=rb_model))

    # Drop the intermediate linregr_train tables, using the same summary
    # name that was read from above.
    plpy.execute('DROP TABLE IF EXISTS ' + lr_out_table)
    plpy.execute('DROP TABLE IF EXISTS ' + lr_out_table_summary)
    plpy.execute("SET client_min_messages TO %s" % old_msg_level)