blob: 1293f57585805d938bffc370997851bc92fb931e [file] [log] [blame]
# coding=utf-8
"""
@file multinom.py_in
@brief Multinomial regression: Driver functions
@namespace glm
@brief Generalized Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import columns_exist_in_table
from glm import __glm_validate_args
from glm import __extract_optim_params
# ========================================================================
def __compute_multinom(arg_dict):
    """
    Run the iterative multinomial fit until the stopping rule fires.

    Every pass hands the ``__multinom_<link>_agg`` aggregate expression to
    the iteration controller and then evaluates the stopping rule: either
    the iteration counter has reached ``max_iter`` or
    ``__multinom_loglik_diff`` between the previous and the current state
    is below ``tolerance``.

    @param arg_dict Dictionary passed to GroupIterationController. The
                    placeholders in the SQL fragments below are presumably
                    filled in by the controller from this dictionary.
    @return Number of iterations that has been run
    """
    # Transition step evaluated once per iteration.
    update_expr = """
{schema_madlib}.__multinom_{link}_agg(
({category_expr})::integer,
({col_ind_var})::double precision[],
{rel_state}.{col_grp_state},
{n_categories}::smallint)
"""
    # Termination test evaluated after each transition step.
    stop_condition = """
{iteration} >= {max_iter}
OR {schema_madlib}.__multinom_loglik_diff(
_state_previous, _state_current) < {tolerance}
"""
    controller = GroupIterationController(arg_dict)
    with controller as it:
        it.iteration = 0
        finished = False
        while not finished:
            it.update(update_expr)
            finished = it.test(stop_condition)
        # Finalize only after the stopping rule has been met.
        it.final()
    return controller.iteration
# ========================================================================
def multinom(schema_madlib, source_table, model_table,
             dependent_varname, independent_varname, ref_category, link_func,
             grouping_col, optim_params, verbose, **kwargs):
    """
    Entry point for multinomial regression training.

    Validates the arguments, fills in defaults for the optional ones and
    delegates the actual work to __multinom_compute().

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the input table
    @param model_table Name of the output (model) table
    @param dependent_varname Dependent variable expression
    @param independent_varname Independent variables expression
    @param ref_category Reference category; when None the first entry of
                        the validated category list is used
    @param link_func Link function; when None 'logit' is used
    @param grouping_col Grouping column(s), or None
    @param optim_params Optimizer parameter string; when None '' is used
    @param verbose Whether informational messages are shown
    @param kwargs Extra keyword arguments, ignored
    @return Whatever __multinom_compute() returns
    """
    category_list = __multinom_validate_args(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, link_func, grouping_col)

    # Substitute defaults for the optional arguments.
    if ref_category is None:
        ref_category = category_list[0]
    if link_func is None:
        link_func = 'logit'
    if optim_params is None:
        optim_params = ''

    optim_params_dict = __extract_optim_params(
        schema_madlib, optim_params, 'Multinom')

    return __multinom_compute(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, category_list, link_func,
        grouping_col, optim_params_dict, verbose)
# ========================================================================
def __multinom_validate_args(
        schema_madlib, source_table, model_table,
        dependent_varname, independent_varname, ref_category,
        link_func, grouping_col):
    """
    Validate the arguments of multinom() and collect the category values.

    After the generic GLM argument validation, the distinct non-null values
    of the dependent variable (over rows whose independent variables contain
    no NULL) are fetched, checked, converted to strings and reordered so
    that the reference category comes first.

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the input table
    @param model_table Name of the output (model) table
    @param dependent_varname Dependent variable expression
    @param independent_varname Independent variables expression
    @param ref_category Reference category as a string, or None
    @param link_func Link function name, or None; only 'logit' is accepted
    @param grouping_col Grouping column(s), or None/empty for no grouping
    @return List of category values as strings, reference category (if
            given) in position 0
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)

    # Compare for equality: the former `not in ('logit')` tested against a
    # plain string, so any substring such as 'log' or '' slipped through.
    if link_func is not None and link_func != 'logit':
        plpy.error("Multinom error: Invalid link function!\n"
                   "Only 'logit' is supported.")

    category_list = plpy.execute("""
SELECT array_agg(category ORDER BY category) AS category_list
FROM (
SELECT distinct {dependent_varname} AS category
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
) subq
""".format(**locals()))[0]['category_list']

    # array_agg over zero rows yields NULL (None here), not an empty array,
    # so test truthiness instead of calling len() on a possible None.
    if not category_list:
        plpy.error("Multinom error: No non-null categories found!")
    if len(category_list) == 1:
        plpy.error("Multinom error: Only a single valid category found!")

    # `long` exists because this module targets Python 2 (PL/Python).
    if not isinstance(category_list[0], (int, float, long, str)):
        plpy.error("Multinom error: Given category type is not supported!\n"
                   "Only numeric, character, binary data and enumerated types "
                   "are supported. Particularly, if the category type is boolean, "
                   "please use glm() binomial family instead.")

    category_list = [str(c) for c in category_list]

    if ref_category is not None and ref_category not in category_list:
        plpy.error("Multinom error: Given ref_category is not found!\n"
                   "'{ref_category}' is not found in source table {source_table}.".
                   format(**locals()))

    # set the reference category in the first position of category list in
    # order to map the reference category to integer 0
    if ref_category is not None and ref_category != category_list[0]:
        i = category_list.index(ref_category)
        category_list[0], category_list[i] = category_list[i], category_list[0]

    if grouping_col:
        # Every group must contain the full set of categories.
        grouped_category_counts = plpy.execute("""
SELECT array_agg(category_count) AS counts
FROM (
SELECT count(distinct {dependent_varname}) AS category_count
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
GROUP BY {grouping_col}
) subq
""".format(**locals()))[0]['counts']
        if any(c != len(category_list) for c in grouped_category_counts):
            plpy.error("Multinom error: Categories are not consistent across "
                       "all groups!")

    return category_list
# ========================================================================
def __multinom_compute(schema_madlib, tbl_source, tbl_output, col_dep_var,
                       col_ind_var, ref_category, category_list, link_func,
                       grouping_col, optim_params_dict, verbose):
    """
    Fit the multinomial model and materialize the result tables.

    Steps performed:
      1. remember the current client_min_messages and set it to 'info'
         (verbose) or 'warning';
      2. assemble the argument dictionary and run __compute_multinom();
      3. create {tbl_output} with one row per non-reference category (and
         per group) and {tbl_output}_summary with run-level information;
      4. drop the temporary state table and restore client_min_messages.

    @param schema_madlib Name of the MADlib schema
    @param tbl_source Name of the input table
    @param tbl_output Name of the output table
    @param col_dep_var Dependent variable expression
    @param col_ind_var Independent variables expression
    @param ref_category Reference category value
    @param category_list Category values as strings; position i is mapped
                         to integer i in the aggregate call
    @param link_func Link function name
    @param grouping_col Grouping column(s), or a falsy value for none
    @param optim_params_dict Optimizer settings merged into the arguments
    @param verbose Whether informational messages are shown
    @return None
    """
    msg_level_row = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]
    old_msg_level = msg_level_row['setting']
    plpy.execute("SET client_min_messages TO " +
                 ("info" if verbose else "warning"))

    args = dict(
        schema_madlib=schema_madlib,
        rel_source=tbl_source,
        tbl_output=tbl_output,
        rel_state=unique_string(),
        col_dep_var=col_dep_var,
        col_ind_var=col_ind_var,
        col_grp_iteration=unique_string(),
        col_grp_state=unique_string(),
        col_n_tuples=unique_string(),
        ref_category=ref_category,
        n_categories=len(category_list),
        link=link_func,
        state_type=schema_madlib + ".bytea8",
    )
    args.update(optim_params_dict)

    # Grouping expressions, each cast to text, joined into one string.
    if grouping_col:
        quoted_cols = _string_to_array_with_quotes(grouping_col)
        grouping_list = explicit_bool_to_text(
            tbl_source, quoted_cols, schema_madlib)
        grouping_str = ','.join([g + "::text" for g in grouping_list])
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str

    # CASE expression translating each category value to its integer index
    # when the aggregate is called.
    when_clauses = []
    for position, category in enumerate(category_list):
        when_clauses.append(
            "WHEN ({col_dep_var})::text = '{c}' THEN {i}".format(
                col_dep_var=col_dep_var, c=category, i=position))
    args['category_expr'] = "CASE " + "\n ".join(when_clauses) + "\nEND"

    # REAL COMPUTATION #
    iteration_run = __compute_multinom(args)

    # SQL fragments that differ between the grouped and ungrouped case.
    if grouping_col is None:
        grouping_str1 = ""
        grouping_str2 = "1 = 1"
        using_str = ""
        join_str = ","
    else:
        grouping_str1 = grouping_col + ","
        grouping_str2 = grouping_col
        using_str = "using (" + grouping_col + ")"
        join_str = "join "

    args['category_str'] = ','.join(category_list)

    # output table
    out_table_template = """
DROP TABLE IF EXISTS {tbl_output};
CREATE TABLE {tbl_output} AS
SELECT
{grouping_str1}
category_list[index+1] AS category,
{schema_madlib}.index_2d_array((result).coef, index) AS coef,
(result).loglik AS log_likelihood,
{schema_madlib}.index_2d_array((result).std_err, index) AS std_err,
{schema_madlib}.index_2d_array((result).z_stats, index) AS z_stats,
{schema_madlib}.index_2d_array((result).p_values, index) AS p_values,
(CASE WHEN result IS NULL THEN 0
ELSE (result).num_rows_processed
END)::bigint AS num_rows_processed,
(CASE WHEN result IS NULL THEN num_rows
ELSE num_rows - (result).num_rows_processed
END)::bigint AS num_rows_skipped,
{col_grp_iteration}::integer AS num_iterations
FROM
(
SELECT
{col_grp_iteration}, {grouping_str1} result, num_rows
FROM
(
(
SELECT
{grouping_str1}
{schema_madlib}.__multinom_result(
{col_grp_state}) AS result,
{col_grp_iteration}
FROM
{rel_state}
) t
JOIN
(
SELECT
{grouping_str1}
max({col_grp_iteration}) AS {col_grp_iteration}
FROM {rel_state}
GROUP BY {grouping_str2}
) s
USING ({grouping_str1} {col_grp_iteration})
) q1
{join_str}
(
SELECT
{grouping_str1}
count(*) AS num_rows
FROM {rel_source}
GROUP BY {grouping_str2}
) q2
{using_str}
) q3,
(
SELECT generate_series(1, {n_categories}-1) AS index
) q4,
(
SELECT '{{{category_str}}}'::varchar[] AS category_list
) q5
"""
    # iteration_run and glm_result are supplied as in the original call even
    # though the template has no placeholder for them.
    q_out_table = out_table_template.format(
        grouping_str1=grouping_str1,
        grouping_str2=grouping_str2,
        iteration_run=iteration_run,
        using_str=using_str,
        join_str=join_str,
        glm_result="__multinom_result",
        **args)
    plpy.execute(q_out_table)

    # summary table: gather the aggregate figures first, one query each,
    # and merge every result row into the argument dictionary.
    failed_groups_sql = """
SELECT count(*) AS num_failed_groups
FROM {tbl_output}
WHERE coef IS NULL
"""
    all_groups_sql = """
SELECT count(*) AS num_all_groups
FROM {tbl_output}
"""
    total_rows_sql = """
SELECT
sum(num_rows_processed) AS total_rows_processed,
sum(num_rows_skipped) AS total_rows_skipped
FROM {tbl_output}
"""
    stats_rows = [
        plpy.execute(stats_sql.format(**args))[0]
        for stats_sql in (failed_groups_sql, all_groups_sql, total_rows_sql)
    ]
    for stats_row in stats_rows:
        args.update(stats_row)

    tbl_output_summary = add_postfix(tbl_output, "_summary")
    if grouping_col:
        g_str = "'" + grouping_col + "'"
    else:
        g_str = "NULL"
    summary_sql = """
CREATE TABLE {tbl_output_summary} AS
SELECT
'multinom'::varchar AS method,
'{rel_source}'::varchar AS source_table,
'{tbl_output}'::varchar AS out_table,
$madlib_super_quote${col_dep_var}$madlib_super_quote$::varchar
AS dependent_varname,
$madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
AS independent_varname,
'{ref_category}'::varchar AS ref_category,
'{category_str}'::varchar AS category_list,
'{link}'::varchar AS link_func,
{g_str}::varchar AS grouping_col,
'optimizer={optimizer}, ' ||
'max_iter={max_iter}, ' ||
'tolerance={tolerance}'::varchar AS optimizer_params,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{total_rows_processed}::bigint AS total_rows_processed,
{total_rows_skipped}::bigint AS total_rows_skipped
""".format(g_str=g_str,
           tbl_output_summary=tbl_output_summary,
           **args)
    plpy.execute(summary_sql)

    # clean up
    plpy.execute("""
DROP TABLE IF EXISTS pg_temp.{rel_state}
""".format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None
# ========================================================================
def multinom_help_msg(schema_madlib, message, **kwargs):
    """ Help message for multinomial linear regression model
    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: empty/None for the
                   summary, 'usage', 'help' or '?' for the usage text
    @param kwargs Extra keyword arguments, ignored
    Returns:
    A string, contains the help message
    """
    if not message:
        help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Multinomial Linear Model:
Currently only logit link functions are supported.
For more details on function usage:
SELECT {schema_madlib}.multinom('usage')
"""
    elif message in ['usage', 'help', '?']:
        # The column names below mirror what the training code creates: the
        # summary table stores the output table name in 'out_table'.
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom(
source_table, -- name of input table
model_table, -- name of model table
dependent_varname, -- name of dependent variable
independent_varname, -- names of independent variables
ref_category, -- optional, parameter for reference category
link_func, -- optional, parameter for link function
grouping_col, -- optional, default NULL, names of columns to group-by
optim_params, -- optional, parameters for optimizer
verbose -- optional, default FALSE, whether to print debug info
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
<...> -- grouping columns
'category' varchar, -- category value
'coef' double precision[], -- vector of coefficients
'log_likelihood' double precision, -- log likelihood
'std_err' double precision[], -- vector of standard errors
'z_stats' double precision[], -- vector of z-statistics
'p_values' double precision[], -- vector of p-values
'num_rows_processed' bigint, -- numbers of rows processed
'num_rows_skipped' bigint, -- numbers of rows skipped
'num_iterations' integer -- number of iterations run
A summary table named <model_table>_summary is also created at the same time, which has:
method varchar, -- modeling method name: 'multinom'
source_table varchar, -- the data source table name
out_table varchar, -- the output table name
dependent_varname varchar, -- the dependent variable
independent_varname varchar, -- the independent variable
ref_category varchar, -- reference category value
category_list varchar, -- all categories used for training
link_func varchar, -- link function
grouping_col varchar, -- grouping columns used in the regression
optimizer_params varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
num_all_groups integer, -- how many groups
num_failed_groups integer, -- how many groups' fitting processes failed
total_rows_processed bigint, -- total numbers of rows processed
total_rows_skipped bigint -- total numbers of rows skipped
"""
    else:
        help_string = "No such option. Use {schema_madlib}.multinom('help')"
    return help_string.format(schema_madlib=schema_madlib)
# ===============================================================================
# Multinomial prediction function
# ===============================================================================
def multinom_predict(schema_madlib, model_table, predict_table,
                     predicted_value_tab, predict_type, verbose,
                     id_column, **kwargs):
    """
    Compute the predicted value for multinomial regression

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param model_table Name of table containing training result from
                       multinom()
    @param predict_table Name of table containing new data to predict
    @param predicted_value_tab Name of table to output the predict value
    @param predict_type Type of predict value: 'response' (the default when
                        None is given) or 'probability' (alias 'prob')
    @param verbose whether the verbose is displayed
    @param id_column Name of ID column in the input table
    @param kwargs We allow the caller to specify additional arguments (all of
           which will be ignored though). The purpose of this is to allow the
           caller to unpack a dictionary whose element set is a superset of
           the required arguments by this function.
    @return None; the result is written to predicted_value_tab
    """
    # Validate the argument
    input_tbl_valid(model_table, 'multinom_predict')
    input_tbl_valid(predict_table, 'multinom_predict')
    output_tbl_valid(predicted_value_tab, 'multinom_predict')

    # Remember the caller's message level so it can be restored on exit,
    # the same way the training code does.
    old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    if predict_type is None:
        predict_type = 'response'

    # Fetch everything needed from the summary table in a single query.
    model_table_summary = add_postfix(model_table, "_summary")
    summary = plpy.execute("""
SELECT ref_category, independent_varname, category_list, grouping_col
FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]
    ref_category = summary['ref_category']
    ind_var = summary['independent_varname']
    cate_list = summary['category_list'].split(',')
    group_var = summary['grouping_col']

    # When the model was trained per group, rows are matched to the model of
    # their own group through equality on every grouping column.
    if group_var is None:
        grp_clause = ""
    else:
        cols_in_tbl_valid(predict_table, _string_to_array(group_var),
                          'multinom_predict')
        group_cols = group_var.split(',')
        grp_clause = "WHERE " + " AND ".join([
            "{predict_table}.{c} = {model_table}.{c}".format(
                c=c, predict_table=predict_table, model_table=model_table)
            for c in group_cols])

    # Name of the id column in the output: the given one when it is a real
    # column of predict_table, a fixed fallback name otherwise.
    if columns_exist_in_table(predict_table, [id_column], schema_madlib):
        multinom_predict_id = id_column
    else:
        multinom_predict_id = 'multinom_predict_id'

    # Placeholders shared by both prediction queries; spelled out explicitly
    # instead of relying on locals().
    sql_args = dict(
        schema_madlib=schema_madlib,
        model_table=model_table,
        predict_table=predict_table,
        predicted_value_tab=predicted_value_tab,
        multinom_predict_id=multinom_predict_id,
        ind_var=ind_var,
        ref_category=ref_category,
        grp_clause=grp_clause,
        id=id_column,
    )

    if predict_type == 'response':
        # The reference category has score 0; the predicted category is the
        # one whose linear score equals the per-row maximum (floored at 0).
        sql = """
CREATE TABLE {predicted_value_tab} AS
SELECT
subq2.{multinom_predict_id},
subq3.category AS category
FROM
(
SELECT
greatest(0, max_score) AS max_score,
{multinom_predict_id}
FROM
(
SELECT
max(
{schema_madlib}.array_dot(coef, {ind_var}::float8[])
) AS max_score,
{id} AS {multinom_predict_id}
FROM
{predict_table},
{model_table}
{grp_clause}
GROUP BY {id}
) subq
) subq2
LEFT JOIN
(
SELECT
{schema_madlib}.array_dot(coef, {ind_var}::float8[]) AS score,
{id} AS {multinom_predict_id},
category::TEXT
FROM
{predict_table},
{model_table}
{grp_clause}
UNION
SELECT
0 AS score,
{id} AS {multinom_predict_id},
'{ref_category}' AS category
FROM {predict_table}
) subq3
ON
(
subq2.{multinom_predict_id} = subq3.{multinom_predict_id}
AND
subq2.max_score=subq3.score
)
ORDER BY subq2.{multinom_predict_id};
""".format(**sql_args)
        plpy.notice(sql)
        plpy.execute(sql)
    elif predict_type in ('probability', 'prob'):
        # One output column per category, in the order stored in the summary.
        score_format = '\n'.join([
            ",score_arr[{j}] as \"{c}\"".format(j=i + 1, c=c)
            for i, c in enumerate(cate_list)])
        score_map = '\n'.join([
            "WHEN category='{c}' THEN {j}".format(j=i + 1, c=c)
            for i, c in enumerate(cate_list)])
        # NOTE: ARRAY{cate_list} relies on the textual form of a Python list
        # of str being a valid SQL array constructor body.
        sql = """
CREATE TABLE {predicted_value_tab} AS
SELECT
{multinom_predict_id}
{score_format}
FROM
(
SELECT
{schema_madlib}.array_scalar_mult(
array_agg(score ORDER BY idx),
1. / {schema_madlib}.array_sum(
array_agg(score)::float8[]
)
) AS score_arr,
array_agg(category ORDER BY idx) AS cate_arr,
{multinom_predict_id}
FROM
(
SELECT
score,
{multinom_predict_id},
subq2.category,
idx
FROM
(
SELECT
exp({schema_madlib}.array_dot(
coef,
{ind_var}::float8[]
)
) AS score,
category,
{id} AS {multinom_predict_id}
FROM
{predict_table},
{model_table}
{grp_clause}
UNION
SELECT
1. AS score,
'{ref_category}' AS category,
{id} AS {multinom_predict_id}
FROM {predict_table}
) subq2
LEFT JOIN
(
SELECT
category,
CASE {score_map} END as idx
FROM
(
SELECT unnest(ARRAY{cate_list}) AS category
) subq
) subq3
ON (subq2.category = subq3.category)
ORDER BY {multinom_predict_id}, idx
) subq4
GROUP BY {multinom_predict_id}
) subq5
ORDER by {multinom_predict_id};
""".format(score_format=score_format,
           score_map=score_map,
           cate_list=cate_list,
           **sql_args)
        plpy.notice(sql)
        plpy.execute(sql)
    else:
        plpy.error("Invalid prediction type!\n")

    # Restore the message level that was active when we were called.
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None
# ========== help message for the prediction function ======================
def multinom_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for prediction function for multinomial regression
    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: empty/None for the
                   summary, 'usage', 'help' or '?' for the usage text
    @param kwargs Extra keyword arguments, ignored
    Returns:
    A string, contains the help message
    """
    if not message:
        help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Prediction function for multinomial linear regression:
Estimate the conditional probability or give the response category given
a new set of predictors.
For more details on function usage:
SELECT {schema_madlib}.multinom_predict('usage')
"""
    elif message in ['usage', 'help', '?']:
        # The predict_type values quoted here must be spelled exactly as the
        # prediction function accepts them.
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom_predict(
model_table, -- Name of the table containing the output of multinom()
predict_table_input, -- Name of the table containing new data
output_table, -- Name of the table storing the result of predicted values
predict_type, -- Support two types: "response" or "probability"
verbose, -- Whether verbose is displayed, default is FALSE
id_column -- Name of the id column in the input table
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the predicted category when predict_type
is response and probability when predict_type is probability.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.multinom_predict('help')"
    return help_string.format(schema_madlib=schema_madlib)