blob: f5469d1456be20a9390d6b02d366825698d7db5e [file] [log] [blame]
# coding=utf-8
"""
@file multinom.py_in
@brief Multinomial regression: Driver functions
@namespace glm
@brief Generalized Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import columns_exist_in_table
from glm import __glm_validate_args
from glm import __extract_optim_params
# ========================================================================
def __compute_multinom(arg_dict):
    """
    Run the iterative fit for the multinomial model

    Each pass applies the link-specific transition aggregate
    (__multinom_<link>_agg) through a GroupIterationController. The loop
    ends once the iteration count reaches max_iter or the log-likelihood
    difference between the previous and the current state falls below
    tolerance; the controller is then finalized.

    @param arg_dict Settings passed to GroupIterationController. The SQL
           templates below are handed over unformatted, so their
           placeholders are presumably resolved by the controller from
           this dictionary (confirm against GroupIterationController).

    @return Number of iterations that has been run
    """
    # One transition step of the aggregate; placeholders are left as-is.
    transition_sql = """
{schema_madlib}.__multinom_{link}_agg(
({category_expr})::integer,
({col_ind_var})::double precision[],
{rel_state}.{col_grp_state},
{n_categories}::smallint)
"""
    # Stop on the iteration cap or on a small enough log-likelihood change.
    stop_sql = """
{iteration} >= {max_iter}
OR {schema_madlib}.__multinom_loglik_diff(
_state_previous, _state_current) < {tolerance}
"""
    controller = GroupIterationController(arg_dict)
    with controller as ctrl:
        ctrl.iteration = 0
        finished = False
        while not finished:
            ctrl.update(transition_sql)
            finished = ctrl.test(stop_sql)
        ctrl.final()
    return controller.iteration
# ========================================================================
def multinom(schema_madlib, source_table, model_table,
             dependent_varname, independent_varname, ref_category, link_func,
             grouping_col, optim_params, verbose, **kwargs):
    """
    Train a multinomial regression model

    Validates the arguments, fills in defaults for the optional ones and
    delegates the actual fitting to __multinom_compute().

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the table holding the training data
    @param model_table Name of the table the model is written to
    @param dependent_varname Expression for the dependent variable
    @param independent_varname Expression for the independent variables
    @param ref_category Reference category; when None, the first entry of
           the validated category list is used
    @param link_func Link function; when None, 'logit' is used
    @param grouping_col Grouping column(s), may be None
    @param optim_params Optimizer parameter string; None is treated as ''
    @param verbose Whether informational messages are shown
    @param kwargs Extra arguments, ignored

    @return Whatever __multinom_compute() returns
    """
    category_list = __multinom_validate_args(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, link_func, grouping_col)

    # default values
    if ref_category is None:
        ref_category = category_list[0]
    if link_func is None:
        link_func = 'logit'
    if optim_params is None:
        optim_params = ''

    optim_params_dict = __extract_optim_params(
        schema_madlib, optim_params, 'Multinom')

    return __multinom_compute(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, category_list, link_func,
        grouping_col, optim_params_dict, verbose)
# ========================================================================
def __multinom_validate_args(
        schema_madlib, source_table, model_table,
        dependent_varname, independent_varname, ref_category,
        link_func, grouping_col):
    """
    Validate the training arguments and collect the category list

    Runs the shared GLM argument validation, checks the link function,
    reads the distinct non-null categories from the source table (rows
    with a NULL or NULL-containing independent variable are ignored) and,
    when grouping is requested, verifies that every group contains the
    same number of distinct categories as the whole table.

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the table holding the training data
    @param model_table Name of the output table
    @param dependent_varname Expression for the dependent variable
    @param independent_varname Expression for the independent variables
    @param ref_category Requested reference category, or None
    @param link_func Requested link function, or None
    @param grouping_col Grouping column(s), or None/empty

    @return List of category values as strings. The categories are sorted
            by the database; if ref_category is given it is swapped into
            position 0 so that it maps to integer 0.

    Errors are reported through plpy.error().
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)

    # ('logit') is just the string 'logit', so the former `not in` test was
    # a substring check that accepted values such as 'log', 'git' or ''.
    if link_func is not None and link_func != 'logit':
        plpy.error("Multinom error: Invalid link function!\n"
                   "Only 'logit' is supported.")

    category_list = plpy.execute("""
SELECT array_agg(category ORDER BY category) AS category_list
FROM (
SELECT distinct {dependent_varname} AS category
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
) subq
""".format(dependent_varname=dependent_varname,
           source_table=source_table,
           independent_varname=independent_varname,
           schema_madlib=schema_madlib))[0]['category_list']

    # array_agg() over zero rows yields NULL (None here), on which len()
    # would raise a TypeError; a truthiness test covers None and [].
    if not category_list:
        plpy.error("Multinom error: No non-null categories found!")
    if len(category_list) == 1:
        plpy.error("Multinom error: Only a single valid category found!")

    # bool is a subclass of int in Python, so boolean categories would pass
    # a plain isinstance(..., int) test even though the message below says
    # they are unsupported (presumably plpy surfaces SQL booleans as Python
    # bool -- confirm for the targeted database versions).
    # `long` is kept because this module targets Python 2.
    first_category = category_list[0]
    if (isinstance(first_category, bool) or
            not isinstance(first_category, (int, float, long, str))):
        plpy.error("Multinom error: Given category type is not supported!\n"
                   "Only numeric, character, binary data and enumerated types "
                   "are supported. Particularly, if the category type is boolean, "
                   "please use glm() binomial family instead.")

    category_list = [str(c) for c in category_list]

    if ref_category is not None and ref_category not in category_list:
        plpy.error("Multinom error: Given ref_category is not found! "
                   "'{ref_category}' is not found in source table {source_table}.".
                   format(ref_category=ref_category, source_table=source_table))

    # set the reference category in the first position of category list in
    # order to map the reference category to integer 0
    if ref_category is not None and ref_category != category_list[0]:
        i = category_list.index(ref_category)
        category_list[0], category_list[i] = category_list[i], category_list[0]

    if grouping_col:
        grouped_category_counts = plpy.execute("""
SELECT array_agg(category_count) AS counts
FROM (
SELECT count(distinct {dependent_varname}) AS category_count
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
GROUP BY {grouping_col}
) subq
""".format(dependent_varname=dependent_varname,
           source_table=source_table,
           independent_varname=independent_varname,
           schema_madlib=schema_madlib,
           grouping_col=grouping_col))[0]['counts']
        # Every group must see the full set of categories.
        if any(c != len(category_list) for c in grouped_category_counts):
            plpy.error("Multinom error: Categories are not consistent across "
                       "all groups!")

    return category_list
# ========================================================================
def __multinom_compute(schema_madlib, tbl_source, tbl_output, col_dep_var,
                       col_ind_var, ref_category, category_list, link_func,
                       grouping_col, optim_params_dict, verbose):
    """
    Fit the multinomial model and write the output and summary tables

    Steps:
      1. switch client_min_messages according to verbose (restored at the
         end on the success path),
      2. build the argument dictionary and the CASE expression that maps
         each category value to its integer index,
      3. run the iterative fit via __compute_multinom(),
      4. create {tbl_output} with one row per group and per non-reference
         category, and {tbl_output}_summary with one row,
      5. drop the temporary state table.

    @param schema_madlib Name of the MADlib schema
    @param tbl_source Name of the source table
    @param tbl_output Name of the output (model) table
    @param col_dep_var Expression for the dependent variable
    @param col_ind_var Expression for the independent variables
    @param ref_category Reference category value
    @param category_list List of category values as strings; position i is
           mapped to integer i in the aggregate
    @param link_func Name of the link function
    @param grouping_col Grouping column(s); empty/None means no grouping
    @param optim_params_dict Optimizer settings merged into the argument
           dictionary; the summary query reads the keys 'optimizer',
           'max_iter' and 'tolerance' from it
    @param verbose Whether informational messages are shown

    @return None
    """
    def _sql_literal(value):
        # Category values come from user data. Dollar-quote them with the
        # tag this module already uses for column expressions, so quotes
        # and backslashes in a value cannot terminate the literal. This
        # only breaks if a value contains the tag itself.
        return "$madlib_super_quote$" + value + "$madlib_super_quote$"

    old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    args = {'schema_madlib': schema_madlib,
            'rel_source': tbl_source,
            'tbl_output': tbl_output,
            'rel_state': unique_string(),
            'col_dep_var': col_dep_var,
            'col_ind_var': col_ind_var,
            'col_grp_iteration': unique_string(),
            'col_grp_state': unique_string(),
            'col_n_tuples': unique_string(),
            'ref_category': ref_category,
            'n_categories': len(category_list),
            'link': link_func,
            'state_type': schema_madlib + ".bytea8"
            }
    args.update(optim_params_dict)

    # grouping columns are cast to text for the iteration controller
    if grouping_col:
        grouping_list = [
            g + "::text"
            for g in explicit_bool_to_text(
                tbl_source,
                _string_to_array_with_quotes(grouping_col),
                schema_madlib)]
        grouping_str = ','.join(grouping_list)
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str

    # build the case when expression to convert category value to integer
    # when aggregate is called
    category_expr_tmp = """\n """.join([
        "WHEN ({col_dep_var})::text = {literal} THEN {i}".
        format(col_dep_var=col_dep_var, literal=_sql_literal(c), i=i)
        for i, c in enumerate(category_list)])
    args['category_expr'] = "CASE " + category_expr_tmp + "\nEND"

    # REAL COMPUTATION #
    __compute_multinom(args)

    # output table
    grouping_str1 = "" if grouping_col is None else grouping_col + ","
    grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
    using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
    join_str = "," if grouping_str1 == "" else "join "
    args['category_str'] = ','.join(category_list)
    # An ARRAY[...] constructor keeps every category exactly as it is. The
    # former '{a,b,...}'::varchar[] literal was re-parsed by the array input
    # routine, which mangles values containing commas, braces, quotes,
    # surrounding blanks or the word NULL.
    args['category_array'] = (
        "ARRAY[" + ",".join(_sql_literal(c) for c in category_list) + "]")
    q_out_table = """
DROP TABLE IF EXISTS {tbl_output};
CREATE TABLE {tbl_output} AS
SELECT
{grouping_str1}
category_list[index+1] AS category,
{schema_madlib}.index_2d_array((result).coef, index) AS coef,
(result).loglik AS log_likelihood,
{schema_madlib}.index_2d_array((result).std_err, index) AS std_err,
{schema_madlib}.index_2d_array((result).z_stats, index) AS z_stats,
{schema_madlib}.index_2d_array((result).p_values, index) AS p_values,
(CASE WHEN result IS NULL THEN 0
ELSE (result).num_rows_processed
END)::bigint AS num_rows_processed,
(CASE WHEN result IS NULL THEN num_rows
ELSE num_rows - (result).num_rows_processed
END)::bigint AS num_rows_skipped,
{col_grp_iteration}::integer AS num_iterations
FROM
(
SELECT
{col_grp_iteration}, {grouping_str1} result, num_rows
FROM
(
(
SELECT
{grouping_str1}
{schema_madlib}.__multinom_result(
{col_grp_state}) AS result,
{col_grp_iteration}
FROM
{rel_state}
) t
JOIN
(
SELECT
{grouping_str1}
max({col_grp_iteration}) AS {col_grp_iteration}
FROM {rel_state}
GROUP BY {grouping_str2}
) s
USING ({grouping_str1} {col_grp_iteration})
) q1
{join_str}
(
SELECT
{grouping_str1}
count(*) AS num_rows
FROM {rel_source}
GROUP BY {grouping_str2}
) q2
{using_str}
) q3,
(
SELECT generate_series(1, {n_categories}-1) AS index
) q4,
(
SELECT {category_array}::varchar[] AS category_list
) q5
""".format(grouping_str1=grouping_str1,
           grouping_str2=grouping_str2,
           using_str=using_str,
           join_str=join_str,
           **args)
    plpy.execute(q_out_table)

    # summary table
    # NOTE(review): the output table holds one row per group AND per
    # non-reference category (cross join with generate_series above), so
    # the three aggregates below count/sum each group (n_categories - 1)
    # times. Left unchanged because it alters published summary values;
    # confirm the intended semantics before correcting.
    failed_groups = plpy.execute("""
SELECT count(*) AS num_failed_groups
FROM {tbl_output}
WHERE coef IS NULL
""".format(**args))[0]
    all_groups = plpy.execute("""
SELECT count(*) AS num_all_groups
FROM {tbl_output}
""".format(**args))[0]
    total_rows = plpy.execute("""
SELECT
sum(num_rows_processed) AS total_rows_processed,
sum(num_rows_skipped) AS total_rows_skipped
FROM {tbl_output}
""".format(tbl_output=tbl_output))[0]
    args.update(failed_groups)
    args.update(all_groups)
    args.update(total_rows)

    tbl_output_summary = add_postfix(tbl_output, "_summary")
    # ref_category and category_list are dollar-quoted like the column
    # expressions so that quotes inside category values stay literal text.
    plpy.execute("""
CREATE TABLE {tbl_output_summary} AS
SELECT
'multinom'::varchar AS method,
'{rel_source}'::varchar AS source_table,
'{tbl_output}'::varchar AS out_table,
$madlib_super_quote${col_dep_var}$madlib_super_quote$::varchar
AS dependent_varname,
$madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
AS independent_varname,
$madlib_super_quote${ref_category}$madlib_super_quote$::varchar AS ref_category,
$madlib_super_quote${category_str}$madlib_super_quote$::varchar AS category_list,
'{link}'::varchar AS link_func,
{g_str}::varchar AS grouping_col,
'optimizer={optimizer}, ' ||
'max_iter={max_iter}, ' ||
'tolerance={tolerance}'::varchar AS optimizer_params,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{total_rows_processed}::bigint AS total_rows_processed,
{total_rows_skipped}::bigint AS total_rows_skipped
""".format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
           tbl_output_summary=tbl_output_summary,
           **args))

    # clean up
    plpy.execute("""
DROP TABLE IF EXISTS pg_temp.{rel_state}
""".format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None
# ========================================================================
def multinom_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the multinomial regression training function

    @param schema_madlib Name of the MADlib schema; substituted into the
           message wherever a function call is shown
    @param message A string, the help message indicator: empty/None for
           the summary, 'usage'/'help'/'?' for the usage text,
           'example'/'examples' for a worked example
    @param kwargs Extra arguments, ignored

    Returns:
    A string, contains the help message
    """
    if not message:
        help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Multinomial Linear Model:
Currently only logit link functions are supported.
For more details on function usage:
SELECT {schema_madlib}.multinom('usage')
For a small example on using the function:
SELECT {schema_madlib}.multinom('example')
"""
    elif message in ['usage', 'help', '?']:
        # The summary column is listed as 'out_table' because that is the
        # name the summary table is actually created with.
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom(
source_table, -- name of input table
model_table, -- name of model table
dependent_varname, -- name of dependent variable
independent_varname, -- names of independent variables
ref_category, -- optional, parameter for reference category
link_func, -- optional, parameter for link function
grouping_col, -- optional, default NULL, names of columns to group-by
optim_params, -- optional, parameters for optimizer
verbose -- optional, default FALSE, whether to print debug info
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
<...> -- grouping columns
'category' varchar, -- category value
'coef' double precision[], -- vector of coefficients
'log_likelihood' double precision, -- log likelihood
'std_err' double precision[], -- vector of standard errors
'z_stats' double precision[], -- vector of z-statistics
'p_values' double precision[], -- vector of p-values
'num_rows_processed' bigint, -- numbers of rows processed
'num_rows_skipped' bigint, -- numbers of rows skipped
'num_iterations' integer -- number of iterations run
A summary table named <model_table>_summary is also created at the same time, which has:
method varchar, -- modeling method name: 'multinom'
source_table varchar, -- the data source table name
out_table varchar, -- the output table name
dependent_varname varchar, -- the dependent variable
independent_varname varchar, -- the independent variable
ref_category varchar, -- reference category value
category_list varchar, -- all categories used for training
link_func varchar, -- link function
grouping_col varchar, -- grouping columns used in the regression
optimizer_params varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
num_all_groups integer, -- how many groups
num_failed_groups integer, -- how many groups' fitting processes failed
total_rows_processed bigint, -- total numbers of rows processed
total_rows_skipped bigint, -- total numbers of rows skipped
"""
    elif message in ['example', 'examples']:
        # The call uses {schema_madlib} like every other message so the
        # example also works when MADlib lives in a different schema.
        help_string = """
DROP TABLE IF EXISTS test3;
CREATE TABLE test3 (
feat1 INTEGER,
feat2 INTEGER,
cat INTEGER
);
INSERT INTO test3(feat1, feat2, cat) VALUES
(1,35,1),
(2,33,0),
(3,39,1),
(1,37,1),
(2,31,1),
(3,36,0),
(2,36,1),
(2,31,1),
(2,41,1),
(2,37,1),
(1,44,1),
(3,33,2),
(1,31,1),
(2,44,1),
(1,35,1),
(1,44,0),
(1,46,0),
(2,46,1),
(2,46,2),
(3,49,1),
(2,39,0),
(2,44,1),
(1,47,1),
(1,44,1),
(1,37,2),
(3,38,2),
(1,49,0),
(2,44,0),
(3,61,2),
(1,65,2),
(3,67,1),
(3,65,2),
(1,65,2),
(2,67,2),
(1,65,2),
(1,62,2),
(3,52,2),
(3,63,2),
(2,59,2),
(3,65,2),
(2,59,0),
(3,67,2),
(3,67,2),
(3,60,2),
(3,67,2),
(3,62,2),
(2,54,2),
(3,65,2),
(3,62,2),
(2,59,2),
(3,60,2),
(3,63,2),
(3,65,2),
(2,63,1),
(2,67,2),
(2,65,2),
(2,62,2);
-- Run the multilogistic regression function.
DROP TABLE IF EXISTS test3_output;
DROP TABLE IF EXISTS test3_output_summary;
SELECT {schema_madlib}.multinom('test3',
'test3_output',
'cat',
'ARRAY[1, feat1, feat2]',
'0',
'logit'
);
SELECT * from test3_output;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.multinom('help')"
    return help_string.format(schema_madlib=schema_madlib)
# ===============================================================================
# Multinomial prediction function
# ===============================================================================
def multinom_predict(schema_madlib, model_table, predict_table,
                     predicted_value_tab, predict_type, verbose,
                     id_column, **kwargs):
    """
    Compute the predicted value for multinomial regression

    Reads the reference category, independent variable expression,
    category list and grouping columns from <model_table>_summary and
    creates predicted_value_tab holding, per input row, either the
    category with the highest score ('response') or one probability
    column per category ('probability' / 'prob').

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param model_table Name of table containing training result from
           multinom()
    @param predict_table Name of table containing new data to predict
    @param predicted_value_tab Name of table to output the predict value
    @param predict_type Type of predict value: 'response' (used when None
           is given), 'probability' or 'prob'
    @param verbose whether the verbose is displayed
    @param id_column Name of ID column in the input table
    @param kwargs We allow the caller to specify additional arguments (all of
           which will be ignored though). The purpose of this is to allow the
           caller to unpack a dictionary whose element set is a superset of
           the required arguments by this function.

    @return None. Invalid arguments are reported through plpy.error().
    """
    def _sql_literal(value):
        # Dollar-quote a category value with the tag this module uses
        # elsewhere, so quotes or backslashes inside the value stay literal.
        return "$madlib_super_quote$" + value + "$madlib_super_quote$"

    # Validate the argument
    input_tbl_valid(model_table, 'multinom_predict')
    input_tbl_valid(predict_table, 'multinom_predict')
    output_tbl_valid(predicted_value_tab, 'multinom_predict')
    if predict_type is None:
        predict_type = 'response'
    # Reject an unknown type before anything is queried or changed.
    if predict_type not in ('response', 'probability', 'prob'):
        plpy.error("Invalid prediction type!\n")

    # Remember the caller's message level so it can be restored at the end,
    # as __multinom_compute() does for training.
    old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    # One round trip for all model metadata instead of one per column.
    model_table_summary = add_postfix(model_table, "_summary")
    summary = plpy.execute("""
SELECT ref_category, independent_varname, category_list, grouping_col
FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]
    ref_category = summary['ref_category']
    ind_var = summary['independent_varname']
    cate_list = summary['category_list'].split(',')
    group_var = summary['grouping_col']

    if group_var is None:
        grp_clause = ""
    else:
        cols_in_tbl_valid(predict_table, _string_to_array(group_var),
                          'multinom_predict')
        # strip blanks so that 'a, b' yields the column names 'a' and 'b'
        group_cols = [c.strip() for c in group_var.split(',')]
        grp_clause = "WHERE " + " AND ".join([
            "{predict_table}.{c} = {model_table}.{c}".format(
                c=c, predict_table=predict_table, model_table=model_table)
            for c in group_cols])

    # id_column is used as the output column name only if it is an actual
    # column of predict_table; otherwise a fixed name is used.
    if columns_exist_in_table(predict_table, [id_column], schema_madlib):
        multinom_predict_id = id_column
    else:
        multinom_predict_id = 'multinom_predict_id'

    sql_args = {
        'schema_madlib': schema_madlib,
        'predicted_value_tab': predicted_value_tab,
        'predict_table': predict_table,
        'model_table': model_table,
        'multinom_predict_id': multinom_predict_id,
        'id': id_column,
        'ind_var': ind_var,
        'grp_clause': grp_clause,
        'ref_literal': _sql_literal(ref_category),
    }

    if predict_type == 'response':
        # The reference category contributes a score of 0; the category
        # whose score equals the per-row maximum is reported.
        sql = """
CREATE TABLE {predicted_value_tab} AS
SELECT
subq2.{multinom_predict_id},
subq3.category AS category
FROM
(
SELECT
greatest(0, max_score) AS max_score,
{multinom_predict_id}
FROM
(
SELECT
max(
{schema_madlib}.array_dot(coef, {ind_var}::float8[])
) AS max_score,
{id} AS {multinom_predict_id}
FROM
{predict_table},
{model_table}
{grp_clause}
GROUP BY {id}
) subq
) subq2
LEFT JOIN
(
SELECT
{schema_madlib}.array_dot(coef, {ind_var}::float8[]) AS score,
{id} AS {multinom_predict_id},
category::TEXT
FROM
{predict_table},
{model_table}
{grp_clause}
UNION
SELECT
0 AS score,
{id} AS {multinom_predict_id},
{ref_literal} AS category
FROM {predict_table}
) subq3
ON
(
subq2.{multinom_predict_id} = subq3.{multinom_predict_id}
AND
subq2.max_score=subq3.score
)
ORDER BY subq2.{multinom_predict_id};
""".format(**sql_args)
    else:
        # one output column per category, named after the category; double
        # quotes inside a name are doubled so the alias stays one identifier
        score_format = '\n'.join([
            ",score_arr[{j}] as \"{c}\"".
            format(j=i + 1, c=c.replace('"', '""'))
            for i, c in enumerate(cate_list)])
        score_map = '\n'.join([
            "WHEN category={literal} THEN {j}".
            format(j=i + 1, literal=_sql_literal(c))
            for i, c in enumerate(cate_list)])
        # Explicit ARRAY[...] of quoted literals; the former ARRAY{cate_list}
        # depended on Python's list repr happening to be valid SQL.
        cate_array = (
            "ARRAY[" + ", ".join(_sql_literal(c) for c in cate_list) + "]")
        sql = """
CREATE TABLE {predicted_value_tab} AS
SELECT
{multinom_predict_id}
{score_format}
FROM
(
SELECT
{schema_madlib}.array_scalar_mult(
array_agg(score ORDER BY idx),
1. / {schema_madlib}.array_sum(
array_agg(score)::float8[]
)
) AS score_arr,
array_agg(category ORDER BY idx) AS cate_arr,
{multinom_predict_id}
FROM
(
SELECT
score,
{multinom_predict_id},
subq2.category,
idx
FROM
(
SELECT
exp({schema_madlib}.array_dot(
coef,
{ind_var}::float8[]
)
) AS score,
category,
{id} AS {multinom_predict_id}
FROM
{predict_table},
{model_table}
{grp_clause}
UNION
SELECT
1. AS score,
{ref_literal} AS category,
{id} AS {multinom_predict_id}
FROM {predict_table}
) subq2
LEFT JOIN
(
SELECT
category,
CASE {score_map} END as idx
FROM
(
SELECT unnest({cate_array}) AS category
) subq
) subq3
ON (subq2.category = subq3.category)
ORDER BY {multinom_predict_id}, idx
) subq4
GROUP BY {multinom_predict_id}
) subq5
ORDER by {multinom_predict_id};
""".format(score_format=score_format,
           score_map=score_map,
           cate_array=cate_array,
           **sql_args)

    plpy.notice(sql)
    plpy.execute(sql)

    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None
# ========== help message for the prediction function ======================
def multinom_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for prediction function for multinomial regression

    @param schema_madlib Name of the MADlib schema; substituted into the
           message wherever a function call is shown
    @param message A string, the help message indicator: empty/None for
           the summary, 'usage'/'help'/'?' for the usage text,
           'example'/'examples' for a worked example
    @param kwargs Extra arguments, ignored

    Returns:
    A string, contains the help message
    """
    if not message:
        help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Prediction function for multinomial linear regression:
Estimate the conditional probability or give the response category given
a new set of predictors.
For more details on function usage:
SELECT {schema_madlib}.multinom_predict('usage')
For a small example on using the function:
SELECT {schema_madlib}.multinom_predict('example')
"""
    elif message in ['usage', 'help', '?']:
        # The predict_type values must be spelled as the predictor accepts
        # them ('response' / 'probability').
        help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom_predict(
model_table, -- Name of the table containing the output of multinom()
predict_table_input, -- Name of the table containing new data
output_table, -- Name of the table storing the result of predicted values
predict_type, -- Support two types: "response" or "probability"
verbose, -- Whether verbose is displayed, default is FALSE
id_column -- Name of the id column in the input table
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the predicted category when predict_type
is response and probability when predict_type is probability.
"""
    elif message in ['example', 'examples']:
        # The model table name matches the one created by the training
        # example ('test3_output'), and the arguments follow the order
        # documented in the usage text (verbose before id_column).
        help_string = """
-- run the training example first
ALTER TABLE test3 ADD COLUMN id SERIAL;
DROP TABLE IF EXISTS test3_predict;
SELECT {schema_madlib}.multinom_predict('test3_output', 'test3', 'test3_predict', 'response', FALSE, 'id');
SELECT * FROM test3_predict;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.multinom_predict('help')"
    return help_string.format(schema_madlib=schema_madlib)