# blob: 49dd9c8985b9adbeb0a26041b8a5d1d92dc70893 [file] [log] [blame]
# coding=utf-8
"""
@file glm.py_in
@brief Generalized Linear Models: Driver functions
@namespace glm
@brief Generalized Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import extract_keyvalue_params
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.utilities import add_postfix
# ========================================================================
def __compute_glm(arg_dict):
    """
    Drive the iterative GLM fit and report how many iterations ran.

    Every iteration advances the per-group state with the family/link
    specific aggregate ``{schema_madlib}.__glm_{family}_{link}_agg``.
    Iteration stops as soon as the iteration counter reaches ``max_iter``
    or the log-likelihood difference between the previous and the current
    state falls below ``tolerance``.

    @param arg_dict Placeholder values for the SQL templates below, handed
        to GroupIterationController
    @return Number of iterations that has been run
    """
    # One update step of the fitting aggregate.
    agg_step_sql = """
{schema_madlib}.__glm_{family}_{link}_agg(
({col_dep_var})::double precision,
({col_ind_var})::double precision[],
{rel_state}.{col_grp_state})
"""
    # Stop when max_iter is hit or the log-likelihood has converged.
    stop_condition_sql = """
{iteration} >= {max_iter}
OR {schema_madlib}.__glm_loglik_diff(
_state_previous, _state_current) < {tolerance}
"""
    controller = GroupIterationController(arg_dict)
    with controller as ctl:
        ctl.iteration = 0
        finished = False
        while not finished:
            ctl.update(agg_step_sql)
            finished = ctl.test(stop_condition_sql)
        # Finalize only once the stopping test has succeeded.
        ctl.final()
    return controller.iteration
# ========================================================================
def glm(schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, family_params=None, grouping_col=None,
        optim_params=None, verbose=False, **kwargs):
    """
    Train a Generalized Linear Model

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source_table Name of relation containing the training data
    @param model_table Name of relation where the model will be written; a
        summary table named <model_table>_summary is created as well
    @param dependent_varname Name of dependent column in training data
    @param independent_varname Name of independent column in training data
        (of type DOUBLE PRECISION[])
    @param family_params Distribution of the dependent variable and link
        function, e.g. 'family=poisson, link=log'. The family is required;
        when the link is omitted the family's default link is used.
    @param grouping_col String of comma delimited group-by columns
    @param optim_params Parameters for the optimizer, e.g.
        'max_iter=100, optimizer=irls, tolerance=1e-6'
    @param verbose If True, client messages down to INFO level are shown
        while training; otherwise only WARNING and above
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return None. The fitted model is written to model_table and run-level
        statistics to <model_table>_summary.
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)
    family_params_dict = __extract_family_params(schema_madlib, family_params)
    optim_params_dict = __extract_optim_params(schema_madlib, optim_params)
    return __glm_compute(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, grouping_col, family_params_dict,
        optim_params_dict, verbose=verbose, **kwargs)
# ========================================================================
def __glm_validate_args(schema_madlib, tbl_source, tbl_output, col_dep_var,
                        col_ind_var, grouping_col):
    """
    Validate the arguments of glm() before any computation starts.

    @param schema_madlib Name of the MADlib schema (unused here; kept for
        interface compatibility)
    @param tbl_source Name of the training-data relation
    @param tbl_output Name of the model table to be created
    @param col_dep_var Dependent-variable expression
    @param col_ind_var Independent-variable expression
    @param grouping_col Comma-delimited grouping columns, or None/empty

    Failures are reported through plpy.error, which raises and aborts the
    call.
    """
    input_tbl_valid(tbl_source, 'GLM')
    output_tbl_valid(tbl_output, 'GLM')
    # The summary table is created with a plain CREATE TABLE only after the
    # whole fit has run; validate its name now (same check as the model
    # table) so a clash is reported before any expensive work is done.
    output_tbl_valid(add_postfix(tbl_output, "_summary"), 'GLM')
    if col_dep_var is None or col_dep_var.strip() == '':
        plpy.error("GLM error: Invalid dependent column name!")
    if col_ind_var is None or col_ind_var.strip() == '':
        plpy.error("GLM error: Invalid independent column name!")
    if grouping_col:
        cols_in_tbl_valid(tbl_source,
                          _string_to_array_with_quotes(grouping_col), 'GLM')
        # Names a grouping column must not take: the model-table output
        # columns, plus the aliases 'result' and 'num_rows' that the
        # model-building query uses internally (a clash would yield
        # duplicate/ambiguous columns in that query).
        reserved_names = frozenset((
            'coef', 'log_likelihood', 'std_err', 'z_stats', 't_stats',
            'p_values', 'odds_ratios', 'condition_no', 'dispersion',
            'num_processed', 'num_missing_rows_skipped',
            'variance_covariance',
            'num_rows_processed', 'num_rows_skipped', 'num_iterations',
            'result', 'num_rows'))
        intersect = frozenset(_string_to_array(grouping_col)) & reserved_names
        if intersect:
            plpy.error("GLM error: Conflicting grouping column name.\n"
                       "Predefined name(s) {0} are not allowed!".format(
                           ', '.join(sorted(intersect))))
    return None
# ========================================================================
def __extract_family_params(schema_madlib, family_params):
    """
    Parse and validate the family_params string of glm().

    @param schema_madlib Name of the MADlib schema (unused; kept for
        interface compatibility)
    @param family_params Key-value string such as 'family=poisson, link=log'
    @return dict with keys 'family' and 'link'. When no link is supplied,
        the family's default link (first entry of its list below) is
        filled in.

    Invalid input is reported through plpy.error, which raises.
    """
    family_params_types = dict(family=str, link=str)
    # Treat an empty/None parse result as "no parameters given" so the user
    # gets the clear "family is missing" error below.
    family_params_dict = extract_keyvalue_params(family_params,
                                                 family_params_types) or {}
    # Allowed link functions per family; the first element is the default.
    family_link = dict(
        poisson=["log", "identity", "sqrt"],
        gaussian=["identity", "log", "inverse"],
        gamma=["inverse", "log", "identity"],
        inverse_gaussian=["sqr_inverse", "identity", "log", "inverse"],
        binomial=["logit", "probit"])
    if "family" not in family_params_dict:
        plpy.error("GLM error: Required parameter family is missing!")
    family = family_params_dict["family"]
    if family not in family_link:
        plpy.error("GLM error: {param_value} is not a valid "
                   "family!".format(param_value=family))
    valid_links = family_link[family]
    if "link" in family_params_dict:
        if family_params_dict["link"] not in valid_links:
            plpy.error("GLM error: Invalid link function {link_func} for "
                       "family {family}!".format(
                           link_func=family_params_dict["link"],
                           family=family))
    else:
        # default link function
        family_params_dict["link"] = valid_links[0]
    return family_params_dict
# ========================================================================
def __extract_optim_params(schema_madlib, optim_params, module='GLM'):
    """
    Parse and validate the optimizer parameter string.

    @param schema_madlib Name of the MADlib schema (unused; kept for
        interface compatibility)
    @param optim_params Key-value string such as
        'max_iter=100, optimizer=irls, tolerance=1e-6'
    @param module Module name used as the prefix of error messages
    @return dict with keys 'max_iter' (int), 'optimizer' (str) and
        'tolerance' (float). The defaults are max_iter=100,
        optimizer='irls' and tolerance=1e-6.

    Invalid values are reported through plpy.error, which raises.
    """
    default_dict = dict(max_iter=100, optimizer='irls', tolerance=1e-6)
    optim_params_types = dict(max_iter=int, optimizer=str, tolerance=float)
    optim_params_dict = extract_keyvalue_params(optim_params,
                                                optim_params_types,
                                                default_dict)
    if optim_params_dict['max_iter'] <= 0:
        plpy.error("{0} error: max_iter must be positive!".format(module))
    # IRLS is the only optimizer implemented.
    if optim_params_dict['optimizer'] != 'irls':
        plpy.error("{0} error: optimizer must be irls!".format(module))
    if optim_params_dict['tolerance'] <= 0:
        plpy.error("{0} error: tolerance must be positive!".format(module))
    return optim_params_dict
# ========================================================================
def __glm_compute(schema_madlib, tbl_source, tbl_output, col_dep_var, col_ind_var,
                  grouping_col, family_params, optim_params, verbose=False, **kwargs):
    """
    Fit the GLM and materialize the model table plus its summary table.

    Steps:
      1. Save the current client_min_messages setting, then set it to INFO
         (verbose) or WARNING (otherwise). It is restored at the end.
      2. Build the SQL template arguments. These include fresh unique names
         for the iteration-state relation and its columns, merged with the
         optimizer and family/link parameters.
      3. Run the iterative fit via __compute_glm. Warn when the iteration
         count reaches max_iter.
      4. Create tbl_output (dropping any existing table of that name) with
         one row per group, built from that group's last-iteration state
         and its source row count.
      5. Create <tbl_output>_summary with run-level statistics.
      6. Drop the iteration-state relation from pg_temp.

    @param schema_madlib Name of the MADlib schema
    @param tbl_source Training-data relation
    @param tbl_output Model table to create
    @param col_dep_var Dependent-variable expression
    @param col_ind_var Independent-variable (array) expression
    @param grouping_col Comma-delimited grouping columns, or None/empty
    @param family_params dict with keys 'family' and 'link'
    @param optim_params dict with keys 'max_iter', 'optimizer', 'tolerance'
    @param verbose Whether INFO-level client messages are shown
    @param kwargs Ignored
    @return None
    """
    # Remember the caller's message level so it can be restored at the end.
    # NOTE(review): it is only restored on the success path; an error
    # presumably relies on transaction rollback to undo the SET — confirm.
    old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")
    # Placeholder values for the SQL templates used below and by
    # __compute_glm; unique_string() provides collision-free names for the
    # iteration-state relation and its columns.
    args = {'schema_madlib': schema_madlib,
            'rel_source': tbl_source,
            'tbl_output': tbl_output,
            'col_dep_var': col_dep_var,
            'col_ind_var': col_ind_var,
            'rel_state': unique_string(),
            'col_grp_iteration': unique_string(),
            'col_grp_state': unique_string(),
            'state_type': schema_madlib + ".bytea8"
            }
    args.update(optim_params)
    args.update(family_params)
    # Build the grouping expression passed to the iteration controller:
    # every grouping column goes through explicit_bool_to_text and is then
    # cast to text. Without grouping, the expression is the literal NULL,
    # and an empty grouping_col is normalized to None so the
    # `is None` checks further down work.
    if grouping_col:
        grouping_list = explicit_bool_to_text(
            tbl_source, _string_to_array_with_quotes(grouping_col), schema_madlib)
        for i in range(len(grouping_list)):
            grouping_list[i] += "::text"
        grouping_str = ','.join(grouping_list)
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str
    # for binomial distribution, the dependent variable is of type boolean.
    # it's cast to integer here so that it can later be type cast to
    # double precision before computation begins.
    if family_params['family'] == 'binomial':
        args['col_dep_var'] = "(" + col_dep_var + ")::integer"
    # REAL COMPUTATION
    iteration_run = __compute_glm(args)
    if iteration_run >= optim_params['max_iter']:
        plpy.warning("GLM warning: the computation did not converge in " +
                     str(optim_params['max_iter']) + " iterations!")
    # output table
    # SQL fragments that differ between grouped and ungrouped runs:
    #   grouping_str1 - grouping columns as a leading select-list prefix
    #   grouping_str2 - GROUP BY expression; the constant "1 = 1" puts all
    #                   rows into a single group when there is no grouping
    #   join_str/using_str - join the per-group results with per-group row
    #                   counts; without grouping, both sides have one row and
    #                   a comma (cross) join is used instead of JOIN ... USING
    grouping_str1 = "" if grouping_col is None else grouping_col + ","
    grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
    using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
    join_str = "," if grouping_str1 == "" else "join "
    # Poisson and binomial report z statistics. All other families expose
    # the same state field under the column name t_stats and use a
    # different result-extraction function.
    if family_params['family'] in ['poisson', 'binomial']:
        res_str = """
(result).z_stats AS z_stats,
(result).p_values AS p_values,
(result).dispersion AS dispersion
"""
        glm_result = "__glm_result_z_stats"
    else:
        res_str = """
(result).z_stats AS t_stats,
(result).p_values AS p_values,
(result).dispersion AS dispersion
"""
        glm_result = "__glm_result_t_stats"
    # Model table: for each group, take the state at that group's highest
    # iteration number (subqueries t and s), then join the group's total
    # source row count (q2). num_rows_skipped = source rows - processed rows;
    # a NULL result counts every row of the group as skipped.
    plpy.execute(
        """
DROP TABLE IF EXISTS {tbl_output};
CREATE TABLE {tbl_output} AS
SELECT
{grouping_str1}
(result).coef AS coef,
(result).loglik AS log_likelihood,
(result).std_err AS std_err,
{res_str},
(CASE WHEN result IS NULL THEN 0
ELSE (result).num_rows_processed
END)::bigint AS num_rows_processed,
(CASE WHEN result IS NULL THEN num_rows
ELSE num_rows - (result).num_rows_processed
END)::bigint AS num_rows_skipped,
{col_grp_iteration}::integer AS num_iterations
FROM (
SELECT
{col_grp_iteration}, {grouping_str1} result, num_rows
FROM (
( SELECT
{grouping_str1}
{schema_madlib}.{glm_result}({col_grp_state}) AS result,
{col_grp_iteration}
FROM
{rel_state}
) t
JOIN
( SELECT
{grouping_str1}
max({col_grp_iteration}) AS {col_grp_iteration}
FROM {rel_state}
GROUP BY {grouping_str2}
) s
USING ({grouping_str1} {col_grp_iteration})
) q1
{join_str}
( SELECT
{grouping_str1}
count(*) AS num_rows
FROM {rel_source}
GROUP BY {grouping_str2}
) q2
{using_str}
) q3
""".format(grouping_str1=grouping_str1,
           grouping_str2=grouping_str2,
           iteration_run=iteration_run,
           using_str=using_str,
           join_str=join_str,
           res_str=res_str,
           glm_result=glm_result,
           **args))
    # summary table
    # Aggregate statistics over the model table. Each query returns one row
    # whose keys (num_failed_groups, num_all_groups, total_rows_processed,
    # total_rows_skipped) are merged into args for the summary template.
    # A group counts as failed when its coef is NULL.
    failed_groups = plpy.execute("""
SELECT count(*) AS num_failed_groups
FROM {tbl_output}
WHERE coef IS NULL
""".format(**args))[0]
    all_groups = plpy.execute("""
SELECT count(*) AS num_all_groups
FROM {tbl_output}
""".format(**args))[0]
    total_rows = plpy.execute("""
SELECT
sum(num_rows_processed) AS total_rows_processed,
sum(num_rows_skipped) AS total_rows_skipped
FROM {tbl_output}
""".format(tbl_output=tbl_output))[0]
    args.update(failed_groups)
    args.update(all_groups)
    args.update(total_rows)
    tbl_output_summary = add_postfix(tbl_output, "_summary")
    # Note: no DROP first, so this CREATE fails if the summary table already
    # exists. The dependent/independent expressions are embedded with
    # dollar quoting; the table names and grouping_col are embedded inside
    # plain single quotes.
    plpy.execute("""
CREATE TABLE {tbl_output_summary} AS
SELECT
'glm'::varchar AS method,
'{rel_source}'::varchar AS source_table,
'{tbl_output}'::varchar AS out_table,
$madlib_super_quote${dcol}$madlib_super_quote$::varchar
AS dependent_varname,
$madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
AS independent_varname,
'family={family}, ' ||
'link={link}'::varchar AS family_params,
{g_str}::text AS grouping_col,
'optimizer={optimizer}, ' ||
'max_iter={max_iter}, ' ||
'tolerance={tolerance}'::varchar AS optimizer_params,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{total_rows_processed}::bigint AS total_rows_processed,
{total_rows_skipped}::bigint AS total_rows_skipped
""".format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
           tbl_output_summary=tbl_output_summary,
           dcol=col_dep_var,
           **args))
    # clean up
    # Drop the iteration-state relation and restore the caller's message level.
    plpy.execute("""DROP TABLE IF EXISTS pg_temp.{rel_state} """.format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None
# ========================================================================
def glm_help_msg(schema_madlib, message, **kwargs):
    """ Help message for generalized linear regression model

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: None/'' for the
        summary, 'usage'/'help'/'?' for usage, 'example'/'examples' for an
        example; anything else yields a pointer to the help option
    Returns:
        A string, contains the help message
    """
    if not message:
        help_string = """
------------------------------------------------------------------
SUMMARY
------------------------------------------------------------------
Generalized Linear Model:
Function to fit a generalized linear model, relating responses to linear combinations
of predictor variables.
For details on function usage:
SELECT {schema_madlib}.glm('usage')
For a small example on using the function:
SELECT {schema_madlib}.glm('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm(
source_table, -- name of input table
model_table, -- name of model table
dependent_varname, -- name of dependent variable
independent_varname, -- names of independent variables
family_params, -- parameters for family distribution and link function
usage:
'family=<family_name>,link=<link_function_name>'
supported values include:
'family=poisson,link=identity|log|sqrt' (default link: log)
'family=binomial,link=logit|probit' (default link: logit)
'family=gaussian,link=identity|log|inverse' (default link: identity)
'family=inverse_gaussian,link=identity|log|inverse|sqr_inverse' (default link: sqr_inverse i.e. 1/mu^2)
'family=gamma,link=identity|log|inverse' (default link: inverse)
grouping_col, -- optional, default NULL, names of columns to group-by
optimizer_params, -- optional, parameters for optimizer
usage:
'max_iter=<max_num_iterations>,optimizer=<optimizer_name>,tolerance=<tolerance_value>'
default values include:
max_iter=100
optimizer='irls'
tolerance=1e-6
verbose -- optional, default FALSE, whether to print debug info
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
<...> -- grouping columns
'coef' double precision[], -- vector of coefficients
'log_likelihood' double precision, -- log likelihood
'std_err' double precision[], -- vector of standard errors
'z_stats'/'t_stats' double precision[], -- vector of z-statistics if family=Poisson or Binomial; vector of t-statistics otherwise
'p_values' double precision[], -- vector of p-values
'dispersion' double precision[], -- dispersion parameter (if z-stats is used, dispersion is set to be constant 1)
'num_rows_processed' bigint, -- number of rows processed
'num_rows_skipped' bigint, -- number of rows skipped
'num_iterations' integer -- number of iterations run
A summary table named <model_table>_summary is also created at the same time, which has:
method varchar, -- modeling method name: 'glm'
source_table varchar, -- the data source table name
out_table varchar, -- the output (model) table name
dependent_varname varchar, -- the dependent variable
independent_varname varchar, -- the independent variable
family_params varchar, -- family distribution and link function
grouping_col varchar, -- grouping columns used in the regression
optimizer_params varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
num_all_groups integer, -- how many groups
num_failed_groups integer, -- how many groups' fitting processes failed
total_rows_processed bigint, -- total number of rows processed
total_rows_skipped bigint -- total number of rows skipped
"""
    elif message in ['example', 'examples']:
        help_string = """
CREATE TABLE warpbreaks(
id serial,
breaks integer,
wool char(1),
tension char(1)
);
INSERT INTO warpbreaks(breaks, wool, tension) VALUES
(26, 'A', 'L'),
(30, 'A', 'L'),
(54, 'A', 'L'),
(25, 'A', 'L'),
(70, 'A', 'L'),
(52, 'A', 'L'),
(51, 'A', 'L'),
(26, 'A', 'L'),
(67, 'A', 'L'),
(18, 'A', 'M'),
(21, 'A', 'M'),
(29, 'A', 'M'),
(17, 'A', 'M'),
(12, 'A', 'M'),
(18, 'A', 'M'),
(35, 'A', 'M'),
(30, 'A', 'M'),
(36, 'A', 'M'),
(36, 'A', 'H'),
(21, 'A', 'H'),
(24, 'A', 'H'),
(18, 'A', 'H'),
(10, 'A', 'H'),
(43, 'A', 'H'),
(28, 'A', 'H'),
(15, 'A', 'H'),
(26, 'A', 'H'),
(27, 'B', 'L'),
(14, 'B', 'L'),
(29, 'B', 'L'),
(19, 'B', 'L'),
(29, 'B', 'L'),
(31, 'B', 'L'),
(41, 'B', 'L'),
(20, 'B', 'L'),
(44, 'B', 'L'),
(42, 'B', 'M'),
(26, 'B', 'M'),
(19, 'B', 'M'),
(16, 'B', 'M'),
(39, 'B', 'M'),
(28, 'B', 'M'),
(21, 'B', 'M'),
(39, 'B', 'M'),
(29, 'B', 'M'),
(20, 'B', 'H'),
(21, 'B', 'H'),
(24, 'B', 'H'),
(17, 'B', 'H'),
(13, 'B', 'H'),
(15, 'B', 'H'),
(15, 'B', 'H'),
(16, 'B', 'H'),
(28, 'B', 'H');
SELECT create_indicator_variables('warpbreaks', 'warpbreaks_dummy', 'wool,tension');
-- Drop output tables before calling the function
DROP TABLE IF EXISTS glm_model;
DROP TABLE IF EXISTS glm_model_summary;
SELECT glm('warpbreaks_dummy',
'glm_model',
'breaks',
'ARRAY[1.0,"wool_B","tension_M", "tension_H"]',
'family=poisson, link=log',
NULL,
'max_iter=100,optimizer=irls,tolerance=1e-6',
true);
SELECT * from glm_model;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.glm('help')"
    return help_string.format(schema_madlib=schema_madlib)
# ========================================================================
def glm_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for glm predict function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: None/'' for the
        summary, 'usage'/'help'/'?' for usage, 'example'/'examples' for an
        example; anything else yields a pointer to the help option
    Returns:
        A string, contains the help message
    """
    if not message:
        help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Prediction function for generalized linear regression:
Estimate the conditional mean for the new predictors. The length of input
coefficients should match the number of variables in the new predictors.
For details on function usage:
SELECT {schema_madlib}.glm_predict('usage')
For a small example on using the function:
SELECT {schema_madlib}.glm_predict('example')
For prediction functions related to specific distributions:
SELECT {schema_madlib}.glm_predict_poisson('help')
SELECT {schema_madlib}.glm_predict_binomial('help')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict(
coef, -- array of coefficients derived from glm() function
col_ind_var, -- array of independent variables for new predictors
link -- string indicating the link function specified in glm()
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated conditional
means for the new predictors.
"""
    elif message in ['example', 'examples']:
        # The prediction call uses the installed schema rather than a
        # hard-coded 'madlib.' prefix, so the example works for any schema.
        help_string = """
DROP TABLE IF EXISTS warpbreaks, warpbreaks_dummy, glm_model, glm_model_summary;
CREATE TABLE warpbreaks(
id serial,
breaks integer,
wool char(1),
tension char(1)
);
INSERT INTO warpbreaks(breaks, wool, tension) VALUES
(26, 'A', 'L'),
(30, 'A', 'L'),
(54, 'A', 'L'),
(25, 'A', 'L'),
(70, 'A', 'L'),
(52, 'A', 'L'),
(51, 'A', 'L'),
(26, 'A', 'L'),
(67, 'A', 'L'),
(18, 'A', 'M'),
(21, 'A', 'M'),
(29, 'A', 'M'),
(17, 'A', 'M'),
(12, 'A', 'M'),
(18, 'A', 'M'),
(35, 'A', 'M'),
(30, 'A', 'M'),
(36, 'A', 'M'),
(36, 'A', 'H'),
(21, 'A', 'H'),
(24, 'A', 'H'),
(18, 'A', 'H'),
(10, 'A', 'H'),
(43, 'A', 'H'),
(28, 'A', 'H'),
(15, 'A', 'H'),
(26, 'A', 'H'),
(27, 'B', 'L'),
(14, 'B', 'L'),
(29, 'B', 'L'),
(19, 'B', 'L'),
(29, 'B', 'L'),
(31, 'B', 'L'),
(41, 'B', 'L'),
(20, 'B', 'L'),
(44, 'B', 'L'),
(42, 'B', 'M'),
(26, 'B', 'M'),
(19, 'B', 'M'),
(16, 'B', 'M'),
(39, 'B', 'M'),
(28, 'B', 'M'),
(21, 'B', 'M'),
(39, 'B', 'M'),
(29, 'B', 'M'),
(20, 'B', 'H'),
(21, 'B', 'H'),
(24, 'B', 'H'),
(17, 'B', 'H'),
(13, 'B', 'H'),
(15, 'B', 'H'),
(15, 'B', 'H'),
(16, 'B', 'H'),
(28, 'B', 'H');
SELECT create_indicator_variables('warpbreaks', 'warpbreaks_dummy', 'wool,tension');
-- Drop output tables before calling the function
DROP TABLE IF EXISTS glm_model;
DROP TABLE IF EXISTS glm_model_summary;
SELECT glm('warpbreaks_dummy',
'glm_model',
'breaks',
'ARRAY[1.0,"wool_B","tension_M", "tension_H"]',
'family=poisson, link=log',
NULL,
'max_iter=100,optimizer=irls,tolerance=1e-6',
true);
SELECT * from glm_model;
SELECT w.id, {schema_madlib}.glm_predict(coef, ARRAY[1, "wool_B", "tension_M", "tension_H"]::float8[],'log') as mu
FROM warpbreaks_dummy w, glm_model m
ORDER BY w.id;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.glm_predict('help')"
    return help_string.format(schema_madlib=schema_madlib)
# ============================================================================
# Help messages for specialized prediction functions
# ============================================================================
def glm_predict_poisson_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the Poisson-specific glm predict function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. None/'' or one of
        'usage', 'help', '?' returns the usage text (a missing message is
        treated like the default, as in the other help functions); anything
        else yields a pointer to the help option
    Returns:
        A string, contains the help message
    """
    if not message or message in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict_poisson(
coef, -- array of coefficients derived from glm() function
col_ind_var, -- array of independent variables for new predictors
link -- string indicating the link function specified in glm()
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated conditional
mean for the new predictors, rounded to the nearest integral value.
For more details on glm predict functions:
SELECT {schema_madlib}.glm_predict('usage')
For examples:
SELECT {schema_madlib}.glm_predict('example')
"""
    else:
        help_string = "No such option. Use {schema_madlib}.glm_predict_poisson('help')"
    return help_string.format(schema_madlib=schema_madlib)
def glm_predict_binomial_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the binomial-specific glm predict function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. None/'' or one of
        'usage', 'help', '?' returns the usage text (a missing message is
        treated like the default, as in the other help functions); anything
        else yields a pointer to the help option
    Returns:
        A string, contains the help message
    """
    if not message or message in ['usage', 'help', '?']:
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict_binomial(
coef, -- array of coefficients derived from glm() function
col_ind_var, -- array of independent variables for new predictors
link -- string indicating the link function specified in glm()
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated output category
of the dependent variable as a boolean value.
For more details on glm predict functions:
SELECT {schema_madlib}.glm_predict('usage')
For examples:
SELECT {schema_madlib}.glm_predict('example')
"""
    else:
        help_string = "No such option. Use {schema_madlib}.glm_predict_binomial('help')"
    return help_string.format(schema_madlib=schema_madlib)