# blob: a81b78d29e8f9c8678f53f9cd9fc931f6395168e [file] [log] [blame]
# ----------------------------------------------------------------------
# Linear regression
# ----------------------------------------------------------------------
import plpy
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import unique_string
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.utilities import add_postfix
from utilities.utilities import _assert
from utilities.control import MinWarning
# ----------------------------------------------------------------------
def _sql_string_literal(text):
    """Return ``text`` as a single-quoted SQL string literal.

    Embedded single quotes are doubled so that expressions such as
    ``ARRAY[1, (c = 'a')::int]`` can be stored verbatim in the summary table.
    """
    return "'%s'" % text.replace("'", "''")


def linregr_train(schema_madlib, source_table, out_table,
                  dependent_varname, independent_varname, grouping_cols,
                  heteroskedasticity_option, **kwargs):
    """
    @brief Train linear regression model(s) and write the result tables

    Creates {out_table} with one row per group (a single row when
    grouping_cols is None) and a one-row companion table named by
    add_postfix(out_table, "_summary") describing the call.

    Args:
        @param schema_madlib  string, Name of the schema madlib
        @param source_table  string, Name of the table holding the input data
        @param out_table  string, Name of the output table to create
        @param dependent_varname  string, Expression for the dependent variable
        @param independent_varname  string, Expression for the independent
                                    variables
        @param grouping_cols  string or None, Grouping column list placed in
                              GROUP BY / USING clauses; None means no grouping
        @param heteroskedasticity_option  boolean, If true the output table
                                          also has bp_stats and bp_p_value
        @param kwargs  Accepted and ignored

    Returns:
        None
    """
    with MinWarning('warning'):
        _validate_args(schema_madlib, source_table, out_table,
                       dependent_varname, independent_varname,
                       grouping_cols, heteroskedasticity_option)

        # SQL fragments that differ between the grouped and ungrouped cases.
        if grouping_cols is None:
            group_str = ''
            group_str_sel = ''
            join_str = ','
            using_str = ''
        else:
            group_str = 'GROUP BY %s' % grouping_cols
            group_str_sel = grouping_cols + ','
            join_str = 'JOIN'
            using_str = 'USING (%s)' % grouping_cols

        # Run linear regression
        temp_lin_rst = unique_string()
        plpy.execute(
            """
            DROP TABLE IF EXISTS {temp_lin_rst};
            CREATE TEMP TABLE {temp_lin_rst} AS
            SELECT
                {group_str_sel}
                ({schema_madlib}.linregr(
                    {dependent_varname},
                    {independent_varname})).*,
                count(*) AS num_rows
            FROM
                {source_table}
            {group_str}
            """.format(schema_madlib=schema_madlib,
                       temp_lin_rst=temp_lin_rst,
                       group_str=group_str,
                       group_str_sel=group_str_sel,
                       dependent_varname=dependent_varname,
                       independent_varname=independent_varname,
                       source_table=source_table))

        # Run heteroskedasticity test
        temp_hsk_rst = None
        if heteroskedasticity_option:
            temp_hsk_rst = unique_string()
            plpy.execute(
                """
                DROP TABLE IF EXISTS {temp_hsk_rst};
                CREATE TEMP TABLE {temp_hsk_rst} AS
                SELECT
                    {group_str_sel}
                    {schema_madlib}.heteroskedasticity_test_linregr(
                        {dependent_varname},
                        {independent_varname},
                        {temp_lin_rst}.coef) AS hsk_rst
                FROM
                    {source_table} {join_str} {temp_lin_rst} {using_str}
                {group_str}
                """.format(schema_madlib=schema_madlib,
                           temp_hsk_rst=temp_hsk_rst,
                           dependent_varname=dependent_varname,
                           independent_varname=independent_varname,
                           group_str_sel=group_str_sel, group_str=group_str,
                           join_str=join_str, using_str=using_str,
                           source_table=source_table,
                           temp_lin_rst=temp_lin_rst))

        # Output the results
        join_str = ''
        using_str = ''
        bp_stats = ''
        bp_p_value = ''
        if heteroskedasticity_option:
            bp_stats = '(hsk.hsk_rst).bp_stats,'
            bp_p_value = '(hsk.hsk_rst).bp_p_value,'
            if grouping_cols is not None:
                join_str = 'JOIN %s AS hsk' % temp_hsk_rst
                using_str = 'USING (%s)' % (grouping_cols)
            else:
                join_str = ', %s AS hsk' % temp_hsk_rst

        plpy.execute(
            """
            CREATE TABLE {out_table} AS
            SELECT
                {group_str_sel}
                coef,
                r2,
                std_err,
                t_stats,
                p_values,
                condition_no,
                {bp_stats}
                {bp_p_value}
                CASE WHEN num_processed IS NULL
                    THEN 0
                    ELSE num_processed
                END AS num_rows_processed,
                CASE WHEN num_processed IS NULL
                    THEN lin.num_rows
                    ELSE lin.num_rows - num_processed
                END AS num_missing_rows_skipped,
                vcov as variance_covariance
            FROM
                {temp_lin_rst} AS lin {join_str} {using_str}
            """.format(out_table=out_table, group_str_sel=group_str_sel,
                       bp_stats=bp_stats, bp_p_value=bp_p_value,
                       temp_lin_rst=temp_lin_rst, join_str=join_str,
                       using_str=using_str))

        totals = plpy.execute(
            """
            select
                sum(num_rows_processed) as num_rows_processed,
                sum(num_missing_rows_skipped) as num_rows_skipped
            from {tbl_output}
            """.format(tbl_output=out_table))[0]
        num_rows_processed = totals['num_rows_processed']
        num_rows_skipped = totals['num_rows_skipped']
        if num_rows_processed is None:
            # sum() over no rows yields NULL; emit the SQL keyword instead of
            # the Python string 'None'.
            num_rows_processed = "NULL"
            num_rows_skipped = "NULL"

        if grouping_cols:
            grouping_col = _sql_string_literal(grouping_cols)
        else:
            grouping_col = "NULL"

        out_table_summary = add_postfix(out_table, "_summary")
        # User-supplied names/expressions are stored as text, so they are
        # quoted as SQL literals rather than pasted between bare quotes.
        plpy.execute(
            """
            create table {out_table_summary} as
            select
                'linregr'::varchar as method
                , {source_table}::varchar as source_table
                , {out_table}::varchar as out_table
                , {dependent_varname}::varchar as dependent_varname
                , {independent_varname}::varchar as independent_varname
                , {num_rows_processed}::integer as num_rows_processed
                , {num_rows_skipped}::integer as num_missing_rows_skipped
                , {grouping_col}::text as grouping_col
            """.format(out_table_summary=out_table_summary,
                       source_table=_sql_string_literal(source_table),
                       out_table=_sql_string_literal(out_table),
                       dependent_varname=_sql_string_literal(
                           dependent_varname),
                       independent_varname=_sql_string_literal(
                           independent_varname),
                       num_rows_processed=num_rows_processed,
                       num_rows_skipped=num_rows_skipped,
                       grouping_col=grouping_col))

        # The intermediate tables are no longer needed; drop them so they do
        # not accumulate for the remainder of the session.
        temp_tables = [temp_lin_rst]
        if temp_hsk_rst is not None:
            temp_tables.append(temp_hsk_rst)
        plpy.execute("DROP TABLE IF EXISTS %s" % ', '.join(temp_tables))
    return None
# ----------------------------------------------------------------------
def _validate_args(schema_madlib, source_table, out_table, dependent_varname,
                   independent_varname, grouping_cols,
                   heteroskedasticity_option):
    """
    @brief validate the arguments

    Each check is made through _assert with a "Linregr error: ..." message.

    Args:
        @param schema_madlib  string, Name of the schema madlib
        @param source_table  string, Must name an existing, non-empty table
        @param out_table  string, Must not name an existing table
        @param dependent_varname  string, Must not be None, empty or 'null'
        @param independent_varname  string, Must not be None, empty or 'null'
        @param grouping_cols  string or None, When given, every column must
                              exist in source_table and must not reuse a
                              column name generated by linregr_train
        @param heteroskedasticity_option  boolean, Must be True or False
    """
    _assert(source_table is not None and
            source_table.strip().lower() not in ('null', ''),
            "Linregr error: Invalid data table name!")
    _assert(table_exists(source_table),
            "Linregr error: Data table does not exist!")
    _assert(not table_is_empty(source_table),
            "Linregr error: Data table is empty!")

    _assert(out_table is not None and
            out_table.strip().lower() not in ('null', ''),
            "Linregr error: Invalid output table name!")
    _assert(not table_exists(out_table, only_first_schema=True),
            "Output table name already exists. Drop the table before calling the function.")

    _assert(dependent_varname is not None and
            dependent_varname.strip().lower() not in ('null', ''),
            "Linregr error: Invalid dependent column name!")
    _assert(independent_varname is not None and
            independent_varname.strip().lower() not in ('null', ''),
            "Linregr error: Invalid independent column name!")
    # Note: We do not further validate dependent_varname/independent_varname
    # because we allow flexible forms (e.g. any valid expression); the
    # trade-off is that an invalid expression is only reported by the
    # database when the query runs.

    _assert(heteroskedasticity_option is not None and
            heteroskedasticity_option in (True, False),
            "Linregr error: Invalid heteroskedasticity_option")

    if grouping_cols is not None:
        _assert(grouping_cols != '',
                "Linregr error: Invalid grouping columns name!")
        grouping_list = _string_to_array_with_quotes(grouping_cols)
        _assert(columns_exist_in_table(
                    source_table, grouping_list, schema_madlib),
                "Linregr error: Grouping column does not exist!")

        # Column names produced by linregr_train, either in its intermediate
        # tables (num_processed, num_rows, vcov) or in the output table. A
        # grouping column with one of these names would yield a duplicate or
        # ambiguous column.
        predefined = set(('coef', 'r2', 'std_err', 't_stats',
                          'p_values', 'condition_no',
                          'num_processed', 'num_rows', 'vcov',
                          'num_rows_processed', 'num_missing_rows_skipped',
                          'variance_covariance'))
        # The heteroskedasticity columns exist only when the test is requested.
        if heteroskedasticity_option:
            predefined.update(('bp_stats', 'bp_p_value', 'hsk_rst'))
        intersect = frozenset(grouping_list).intersection(predefined)
        _assert(not intersect,
                "Linregr error: Conflicting grouping column name.\n"
                "Predefined name(s) {0} are not allowed!".
                format(', '.join(sorted(intersect))))
# ------------------------------------------------------------------------------
# -- Online help function ------------------------------------------------------
# ------------------------------------------------------------------------------
def linregr_help_message(schema_madlib, message, **kwargs):
    """ Help message for Linear Regression

    @brief Build the online help text for linregr_train
    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator: empty/None for the
                       summary, 'usage'/'help'/'?' for usage, and
                       'example'/'examples' for a worked example
    Returns:
        String. Contain the help message string
    """
    # NOTE: every help text below is passed through str.format(), so literal
    # braces must not be added to them without doubling.
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Ordinary Least Squares Regression, also called Linear Regression, is a
statistical model used to fit linear models.
It models a linear relationship of a scalar dependent variable y to one
or more explanatory independent variables x to build a
model of coefficients.
For more details on function usage:
SELECT {schema_madlib}.linregr_train('usage')
For an example on using the function:
SELECT {schema_madlib}.linregr_train('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.linregr_train(
source_table, -- name of input table
out_table, -- name of output table
dependent_varname, -- name of dependent variable
independent_varname, -- name of independent variables
grouping_cols, -- names of columns to group-by
heteroskedasticity_option -- perform heteroskedasticity test?
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('out_table' above) has the following columns:
<...>, -- Grouping columns used during training
'coef' DOUBLE PRECISION[], -- Vector of coefficients
'r2' DOUBLE PRECISION, -- R-squared coefficient
'std_err' DOUBLE PRECISION[], -- Standard errors of coefficients
't_stats' DOUBLE PRECISION[], -- t-stats of the coefficients
'p_values' DOUBLE PRECISION[], -- p-values of the coefficients
'condition_no' INTEGER, -- The condition number of the covariance matrix.
'bp_stats' DOUBLE PRECISION, -- The Breusch-Pagan statistic of heteroskedasticity.
(if heteroskedasticity_option=TRUE)
'bp_p_value' DOUBLE PRECISION, -- The Breusch-Pagan calculated p-value.
(if heteroskedasticity_option=TRUE)
'num_rows_processed' INTEGER, -- Number of rows that are actually used in each group
'num_missing_rows_skipped' INTEGER, -- Number of rows that have NULL and are skipped in each group
'variance_covariance' DOUBLE PRECISION[] -- Variance/covariance matrix of the coefficients
A summary table named <out_table>_summary is also created at the same time, which has:
'method' VARCHAR, -- the name of the method, 'linregr'
'source_table' VARCHAR, -- the data source table name
'out_table' VARCHAR, -- the output table name
'dependent_varname' VARCHAR, -- the dependent variable
'independent_varname' VARCHAR, -- the independent variable
'num_rows_processed' INTEGER, -- total number of rows that are used
'num_missing_rows_skipped' INTEGER, -- total number of rows that are skipped because of NULL values
'grouping_col' TEXT -- the grouping columns, NULL if none were given
"""
    elif message in ['example', 'examples']:
        # The COPY terminator is written as '\\.' so that the text contains a
        # literal backslash-dot without relying on an invalid escape sequence.
        help_string = """
CREATE TABLE houses (id INT, tax INT,
bedroom INT, bath FLOAT,
price INT, size INT, lot INT);
COPY houses FROM STDIN WITH DELIMITER '|';
1 | 590 | 2 | 1 | 50000 | 770 | 22100
2 | 1050 | 3 | 2 | 85000 | 1410 | 12000
3 | 20 | 3 | 1 | 22500 | 1060 | 3500
4 | 870 | 2 | 2 | 90000 | 1300 | 17500
5 | 1320 | 3 | 2 | 133000 | 1500 | 30000
6 | 1350 | 2 | 1 | 90500 | 820 | 25700
7 | 2790 | 3 | 2.5 | 260000 | 2130 | 25000
8 | 680 | 2 | 1 | 142500 | 1170 | 22000
9 | 1840 | 3 | 2 | 160000 | 1500 | 19000
10 | 3680 | 4 | 2 | 240000 | 2790 | 20000
11 | 1660 | 3 | 1 | 87000 | 1030 | 17500
12 | 1620 | 3 | 2 | 118600 | 1250 | 20000
13 | 3100 | 3 | 2 | 140000 | 1760 | 38000
14 | 2070 | 2 | 3 | 148000 | 1550 | 14000
15 | 650 | 3 | 1.5 | 65000 | 1450 | 12000
\\.
-- Train a regression model. First, single regression for all data.
SELECT {schema_madlib}.linregr_train( 'houses',
'houses_linregr',
'price',
'ARRAY[1, tax, bath, size]'
);
-- Generate three output models, one for each value of "bedroom".
SELECT {schema_madlib}.linregr_train('houses',
'houses_linregr_bedroom',
'price',
'ARRAY[1, tax, bath, size]',
'bedroom'
);
-- Examine the resulting models.
SELECT * FROM houses_linregr;
SELECT * FROM houses_linregr_bedroom;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.linregr_train()"
    return help_string.format(schema_madlib=schema_madlib)
def linregr_predict_help_message(schema_madlib, message, **kwargs):
    """ Help message for Prediction in Linear Regression

    @brief Build the online help text for linregr_predict
    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator: empty/None for the
                       summary, 'usage'/'help'/'?' for usage, and
                       'example'/'examples' for a worked example
    Returns:
        String. Contain the help message string
    """
    # NOTE: every help text below is passed through str.format(), so the
    # schema is always written as the schema_madlib placeholder rather than a
    # hard-coded schema name.
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Prediction Function for Ordinary Least Squares Regression (Linear
Regression), is a simple wrapper for dot product function
{schema_madlib}.array_dot.
For more details on function usage:
SELECT {schema_madlib}.linregr_predict('usage')
For an example on using the function:
SELECT {schema_madlib}.linregr_predict('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
linregr_predict(
coef, -- DOUBLE PRECISION[], Coefficient of linear regression model
col_ind_var -- DOUBLE PRECISION[], Values for the independent variables
)
The lengths of 'coef' array and 'col_ind_var' array above should be equal. For
a small example on using the function:
SELECT {schema_madlib}.linregr_predict('example')
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output of the function is a DOUBLE PRECISION value representing the
predicted dependent variable value.
"""
    elif message in ['example', 'examples']:
        help_string = """
-- The tables below are obtained from the example in 'linregr_train'.
-- Details can be found by running "SELECT {schema_madlib}.linregr_train('examples');"
SELECT id,
{schema_madlib}.linregr_predict(coef, ARRAY[1, tax, bath, size]) as pred_value
FROM houses h, houses_linregr m
ORDER BY id;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.linregr_predict()"
    return help_string.format(schema_madlib=schema_madlib)