blob: 364b380e36f832bed68c12fe122062d0bce23b28 [file] [log] [blame]
# ----------------------------------------------------------------------
# Linear regression
# ----------------------------------------------------------------------
import plpy
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import unique_string
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.utilities import add_postfix
from utilities.utilities import _assert
from utilities.control import MinWarning
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import strip_end_quotes
from utilities.utilities import _string_to_array
from utilities.validate_args import get_expr_type
from utilities.utilities import _string_to_sql_array
from utilities.validate_args import quote_ident
# ----------------------------------------------------------------------
def linregr_train(schema_madlib, source_table, out_table,
                  dependent_varname, independent_varname, grouping_cols,
                  heteroskedasticity_option, **kwargs):
    """
    Fit a linear regression (optionally per group) and store the results.

    Two tables are created: ``out_table`` holding the model statistics (one
    row per group when ``grouping_cols`` is given) and a companion summary
    table named by ``add_postfix(out_table, "_summary")`` holding the call
    parameters and the row totals.

    Args:
        @param schema_madlib, -- schema that contains the MADlib functions
        @param source_table, -- name of input table
        @param out_table, -- name of output table
        @param dependent_varname, -- name of dependent variable
        @param independent_varname, -- name of independent variables
        @param grouping_cols, -- names of columns to group-by (None for no grouping)
        @param heteroskedasticity_option -- perform heteroskedasticity test?
        @param kwargs -- accepted for call compatibility; not used

    Returns:
        None
    """
    with MinWarning('warning'):
        _validate_args(schema_madlib, source_table, out_table,
                       dependent_varname, independent_varname,
                       grouping_cols, heteroskedasticity_option)

        if grouping_cols is None:
            group_col_list = []
            group_str = ''
            group_cols_w_quotes = ''
            group_str_sel = ''
            join_str = ','
        else:
            group_col_list = _string_to_array_with_quotes(grouping_cols)
            group_str = 'GROUP BY %s' % get_table_qualified_col_str(
                source_table, group_col_list)
            # For json expressions like data->>'ID', this will quote them to
            # "data->>'ID'".
            # NOTE: the grouping expressions are only ever concatenated or
            # passed as format *arguments*; they must never be used as a
            # format template, because expressions such as
            # data #>> '{a,b}' legitimately contain braces.
            group_cols_w_quotes = ' ,'.join(
                " \"" + strip_end_quotes(col) + "\" "
                for col in group_col_list)
            # For json expressions like data->>'ID' it will alias the columns
            # to "data->>'ID'"
            group_str_sel = ' , '.join(
                "{source_table}.{col} as \"".format(
                    source_table=source_table, col=col) +
                strip_end_quotes(col) + "\" "
                for col in group_col_list) + ","
            join_str = 'JOIN'

        # Run linear regression
        temp_lin_rst = unique_string()
        temp_hsk_rst = None
        plpy.execute(
            """
            DROP TABLE IF EXISTS {temp_lin_rst};
            CREATE TEMP TABLE {temp_lin_rst} AS
            SELECT
                {group_str_sel}
                ({schema_madlib}.linregr(
                    {dependent_varname},
                    {independent_varname})).*,
                count(*) AS num_rows
            FROM
                {source_table}
            {group_str}
            """.format(schema_madlib=schema_madlib,
                       temp_lin_rst=temp_lin_rst,
                       group_str=group_str,
                       group_str_sel=group_str_sel,
                       dependent_varname=dependent_varname,
                       independent_varname=independent_varname,
                       source_table=source_table))

        # USING Clause does not support expressions. So modifying to the
        # regular join clause instead.
        if grouping_cols is None:
            join_clause = ''
        else:
            join_clause = " ON " + ' AND '.join(
                "{source_table}.{col} = {temp_lin_rst}.\"".format(
                    source_table=source_table, col=col,
                    temp_lin_rst=temp_lin_rst) +
                strip_end_quotes(col) + "\""
                for col in group_col_list)

        # Run heteroskedasticity test
        if heteroskedasticity_option:
            temp_hsk_rst = unique_string()
            plpy.execute(
                """
                DROP TABLE IF EXISTS {temp_hsk_rst};
                CREATE TEMP TABLE {temp_hsk_rst} AS
                SELECT
                    {group_str_sel}
                    {schema_madlib}.heteroskedasticity_test_linregr(
                        {dependent_varname},
                        {independent_varname},
                        {temp_lin_rst}.coef) AS hsk_rst
                FROM
                    {source_table} {join_str} {temp_lin_rst} {join_clause}
                {group_str}
                """.format(schema_madlib=schema_madlib,
                           temp_hsk_rst=temp_hsk_rst,
                           dependent_varname=dependent_varname,
                           independent_varname=independent_varname,
                           group_str_sel=group_str_sel, group_str=group_str,
                           join_str=join_str, source_table=source_table,
                           temp_lin_rst=temp_lin_rst,
                           join_clause=join_clause))

        # Output the results
        out_join_str = ''
        out_using_str = ''
        bp_stats = ''
        bp_p_value = ''
        if heteroskedasticity_option:
            bp_stats = '(hsk.hsk_rst).bp_stats,'
            bp_p_value = '(hsk.hsk_rst).bp_p_value,'
            if grouping_cols is not None:
                # Both temp tables alias the grouping expressions to the same
                # quoted names, so USING works here.
                out_join_str = 'JOIN %s AS hsk' % temp_hsk_rst
                out_using_str = 'USING (%s)' % group_cols_w_quotes
            else:
                out_join_str = ', %s AS hsk' % temp_hsk_rst
        out_group_cols = ('' if grouping_cols is None
                          else group_cols_w_quotes + ",")
        plpy.execute(
            """
            CREATE TABLE {out_table} AS
            SELECT
                {group_cols_w_quotes}
                coef,
                r2,
                std_err,
                t_stats,
                p_values,
                condition_no,
                {bp_stats}
                {bp_p_value}
                CASE WHEN num_processed IS NULL
                    THEN 0
                    ELSE num_processed
                END AS num_rows_processed,
                CASE WHEN num_processed IS NULL
                    THEN lin.num_rows
                    ELSE lin.num_rows - num_processed
                END AS num_missing_rows_skipped,
                vcov as variance_covariance
            FROM
                {temp_lin_rst} AS lin {join_str} {using_str}
            """.format(out_table=out_table,
                       bp_stats=bp_stats, bp_p_value=bp_p_value,
                       temp_lin_rst=temp_lin_rst,
                       join_str=out_join_str, using_str=out_using_str,
                       group_cols_w_quotes=out_group_cols))

        totals = plpy.execute(
            """
            select
                sum(num_rows_processed) as num_rows_processed,
                sum(num_missing_rows_skipped) as num_rows_skipped
            from {tbl_output}
            """.format(tbl_output=out_table))[0]
        num_rows_processed = totals['num_rows_processed']
        num_rows_skipped = totals['num_rows_skipped']
        if num_rows_processed is None:
            # Render SQL NULLs in the summary when the sum is NULL.
            num_rows_processed = "NULL"
            num_rows_skipped = "NULL"

        if grouping_cols:
            grouping_col = "$__madlib__$" + grouping_cols + "$__madlib__$"
        else:
            grouping_col = "NULL"

        out_table_summary = add_postfix(out_table, "_summary")
        # All user-supplied strings are dollar-quoted so that names or
        # expressions containing single quotes do not break the statement.
        plpy.execute(
            """
            create table {out_table_summary} as
            select
                'linregr'::varchar as method
                , $__madlib__${source_table}$__madlib__$::varchar as source_table
                , $__madlib__${out_table}$__madlib__$::varchar as out_table
                , $__madlib__${dependent_varname}$__madlib__$::varchar as dependent_varname
                , $__madlib__${independent_varname}$__madlib__$::varchar as independent_varname
                , {num_rows_processed}::BIGINT as num_rows_processed
                , {num_rows_skipped}::integer as num_missing_rows_skipped
                , {grouping_col}::text as grouping_col
            """.format(out_table=out_table, source_table=source_table,
                       out_table_summary=out_table_summary,
                       dependent_varname=dependent_varname,
                       independent_varname=independent_varname,
                       grouping_col=grouping_col,
                       num_rows_processed=num_rows_processed,
                       num_rows_skipped=num_rows_skipped))

        # The intermediate tables are no longer needed once the output tables
        # exist; drop them instead of leaving them around for the session.
        plpy.execute("DROP TABLE IF EXISTS {0}".format(temp_lin_rst))
        if temp_hsk_rst is not None:
            plpy.execute("DROP TABLE IF EXISTS {0}".format(temp_hsk_rst))
    return None
# ----------------------------------------------------------------------
def _validate_args(schema_madlib, source_table, out_table, dependent_varname,
                   independent_varname, grouping_cols, heteroskedasticity_option):
    """
    @brief validate the arguments

    Every check goes through _assert with a user-facing message. The checks
    cover: the source table is named, exists and is non-empty; the output
    table is named and does not already exist; the dependent and independent
    variable strings are non-empty; heteroskedasticity_option is a boolean;
    and, when grouping columns are given, none of them collides with a column
    name that the training routine generates itself.

    Args:
        @param schema_madlib -- MADlib schema name (not used by the checks)
        @param source_table -- name of input table
        @param out_table -- name of output table
        @param dependent_varname -- dependent variable expression
        @param independent_varname -- independent variables expression
        @param grouping_cols -- comma-separated grouping columns, or None
        @param heteroskedasticity_option -- True/False
    """
    _assert(source_table is not None and
            source_table.strip().lower() not in ('null', ''),
            "Linregr error: Invalid data table name!")
    _assert(table_exists(source_table),
            "Linregr error: Data table does not exist!")
    _assert(not table_is_empty(source_table),
            "Linregr error: Data table is empty!")
    _assert(out_table is not None and
            out_table.strip().lower() not in ('null', ''),
            "Linregr error: Invalid output table name!")
    _assert(not table_exists(out_table, only_first_schema=True),
            "Output table name already exists. Drop the table before calling the function.")
    _assert(dependent_varname is not None and
            dependent_varname.strip().lower() not in ('null', ''),
            "Linregr error: Invalid dependent column name!")
    _assert(independent_varname is not None and
            independent_varname.strip().lower() not in ('null', ''),
            "Linregr error: Invalid independent column name!")
    # Note: We do not further validate dependent_varname/independent_varname
    # because we allow flexible forms (e.g. valid expression), but we need
    # an informative error messages when they are not valid
    _assert(heteroskedasticity_option is not None and
            heteroskedasticity_option in (True, False),
            "Linregr error: Invalid heteroskedasticity_option")
    if grouping_cols is not None:
        _assert(grouping_cols != '',
                "Linregr error: Invalid grouping columns name!")
        # Grouping columns can be a valid expression as well, e.g. a json
        # expression (data->>'id'), so their existence as plain columns of
        # the source table is deliberately not checked.
        grouping_list = _string_to_array_with_quotes(grouping_cols)
        # Names generated by the training routine. This covers the final
        # output columns as well as the columns of the intermediate result
        # table (num_processed, vcov, num_rows); a grouping column with any
        # of these names would otherwise fail later with a duplicate-column
        # SQL error.
        predefined = set(('coef', 'r2', 'std_err', 't_stats',
                          'p_values', 'condition_no',
                          'num_processed', 'num_rows_processed',
                          'num_missing_rows_skipped',
                          'num_rows', 'vcov',
                          'variance_covariance'))
        # The Breusch-Pagan columns are only produced when the test is
        # requested, so they are only reserved in that case.
        if heteroskedasticity_option:
            predefined.update(('bp_stats', 'bp_p_value'))
        intersect = frozenset(grouping_list).intersection(predefined)
        _assert(not intersect,
                "Linregr error: Conflicting grouping column name.\n"
                "Predefined name(s) {0} are not allowed!".
                format(', '.join(sorted(intersect))))
# ------------------------------------------------------------------------------
# -- Online help function ------------------------------------------------------
# ------------------------------------------------------------------------------
def linregr_help_message(schema_madlib, message, **kwargs):
    """ Help message for Linear Regression

    @brief Build the online help text for linregr_train

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator. An empty/None value
                       selects the summary; 'usage', 'help' or '?' selects
                       the usage text; anything else yields a
                       "No such option" notice.
        @param kwargs accepted for call compatibility; not used

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Ordinary Least Squares Regression, also called Linear Regression, is a
statistical model used to fit linear models.
It models a linear relationship of a scalar dependent variable y to one
or more explanatory independent variables x to build a
model of coefficients.
For more details on function usage:
SELECT {schema_madlib}.linregr_train('usage')
"""
    elif message in ['usage', 'help', '?']:
        # Keep this listing in sync with the columns created by
        # linregr_train (output table and summary table).
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.linregr_train(
source_table, -- name of input table
out_table, -- name of output table
dependent_varname, -- name of dependent variable
independent_varname, -- name of independent variables
grouping_cols, -- names of columns to group-by
heteroskedasticity_option -- perform heteroskedasticity test?
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('out_table' above) has the following columns:
<...>, -- Grouping columns used during training
'coef' DOUBLE PRECISION[], -- Vector of coefficients
'r2' DOUBLE PRECISION, -- R-squared coefficient
'std_err' DOUBLE PRECISION[], -- Standard errors of coefficients
't_stats' DOUBLE PRECISION[], -- t-stats of the coefficients
'p_values' DOUBLE PRECISION[], -- p-values of the coefficients
'condition_no' INTEGER, -- The condition number of the covariance matrix.
'bp_stats' DOUBLE PRECISION, -- The Breusch-Pagan statistic of heteroskedasticity.
(if heteroskedasticity_option=TRUE)
'bp_p_value' DOUBLE PRECISION, -- The Breusch-Pagan calculated p-value.
(if heteroskedasticity_option=TRUE)
'num_rows_processed' INTEGER, -- Number of rows that are actually used in each group
'num_missing_rows_skipped' INTEGER, -- Number of rows that have NULL and are skipped in each group
'variance_covariance' DOUBLE PRECISION[] -- Variance/covariance matrix of the coefficients
A summary table named <out_table>_summary is also created at the same time, which has:
'method' VARCHAR, -- the modeling method, always 'linregr'
'source_table' VARCHAR, -- the data source table name
'out_table' VARCHAR, -- the output table name
'dependent_varname' VARCHAR, -- the dependent variable
'independent_varname' VARCHAR, -- the independent variable
'num_rows_processed' BIGINT, -- total number of rows that are used
'num_missing_rows_skipped' INTEGER, -- total number of rows that are skipped because of NULL values
'grouping_col' TEXT -- the grouping columns (NULL if no grouping was used)
"""
    else:
        help_string = "No such option. Use {schema_madlib}.linregr_train()"
    return help_string.format(schema_madlib=schema_madlib)
def linregr_predict_help_message(schema_madlib, message, **kwargs):
    """ Help message for Prediction in Linear Regression

    @brief Build the online help text for linregr_predict

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator. An empty/None value
                       selects the summary; 'usage', 'help' or '?' selects
                       the usage text; 'example' or 'examples' selects a
                       sample query; anything else yields a
                       "No such option" notice.
        @param kwargs accepted for call compatibility; not used

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Prediction Function for Ordinary Least Squares Regression (Linear
Regression), is simple wrapper for dot product function
{schema_madlib}.array_dot.
For more details on function usage:
SELECT {schema_madlib}.linregr_predict('usage')
For an example on using the function:
SELECT {schema_madlib}.linregr_predict('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
linregr_predict(
coef, -- DOUBLE PRECISION[], Coefficient of linear regression model
col_ind_var -- DOUBLE PRECISION[], Values for the independent variables
)
The lengths of 'coef' array and 'col_ind_var' array above should be equal. For
a small example on using the function:
SELECT {schema_madlib}.linregr_predict('example')
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output of the function is a DOUBLE PRECISION value representing the
predicted dependent variable value.
"""
    elif message in ['example', 'examples']:
        # The function call is qualified with the installed schema name
        # rather than a hard-coded 'madlib'.
        help_string = """
-- The tables below are obtained from the example in 'linregr_train'.
-- Details can be found by running "SELECT linregr_train('examples');"
SELECT id,
{schema_madlib}.linregr_predict(coef, ARRAY[1, tax, bath, size]) as pred_value
FROM houses h, houses_linregr m
ORDER BY id;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.linregr_predict()"
    return help_string.format(schema_madlib=schema_madlib)