# coding=utf-8
"""
@file ordinal.py_in
@brief Ordinal regression: Driver functions
@namespace glm
@brief Ordinal Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from glm import __glm_validate_args
from glm import __extract_optim_params
# ========================================================================
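# The argument dictionary consumed below is assembled in __ordinal_compute().
# A minimal sketch of the keys referenced by the SQL templates in this
# function (values are illustrative only, not defaults):
#
#     {'schema_madlib': 'madlib',
#      'link': 'logit',                      # or 'probit'
#      'category_expr': 'CASE WHEN ... END', # built in __ordinal_compute()
#      'col_ind_var': 'ARRAY[1, x1, x2]',    # hypothetical expression
#      'n_categories': 3,
#      'rel_state': '<unique state table>',
#      'col_grp_state': '<unique state column>',
#      'max_iter': 100,
#      'tolerance': 1e-6,
#      ...}                                  # plus source, grouping and
#                                            # iteration bookkeeping keys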
def __compute_ordinal(arg_dict):
"""
Compute Ordinal Model coefficients
This method serves as an interface to different optimization algorithms.
By default, iteratively reweighted least squares is used.
    @return Number of iterations that have been run
"""
iterationCtrl = GroupIterationController(arg_dict)
with iterationCtrl as it:
it.iteration = 0
while True:
it.update(
"""
{schema_madlib}.__ordinal_{link}_agg(
({category_expr})::integer,
({col_ind_var})::double precision[],
{rel_state}.{col_grp_state},
{n_categories}::smallint)
""")
if it.test(
"""
{iteration} >= {max_iter}
OR {schema_madlib}.__ordinal_loglik_diff(
_state_previous, _state_current) < {tolerance}
"""):
it.final()
break
return iterationCtrl.iteration
# ========================================================================
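# ordinal() is the Python driver behind the SQL interface documented in
# ordinal_help_msg() below. One possible invocation, with hypothetical table
# and column names (trailing optional arguments omitted where the deployed
# function overloads allow it):
#
#     SELECT madlib.ordinal('patient_data',          -- source_table
#                           'patient_model',         -- model_table
#                           'severity',              -- dependent_varname
#                           'ARRAY[1, age, bmi]',    -- independent_varname
#                           'mild<moderate<severe',  -- cat_order
#                           'logit');                -- link_func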
def ordinal(schema_madlib, source_table, model_table,
dependent_varname, independent_varname, cat_order, link_func,
grouping_col, optim_params, verbose, **kwargs):
category_list = __ordinal_validate_args(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, cat_order, link_func, grouping_col)
# default values
link_func = 'logit' if link_func is None else link_func
optim_params = '' if optim_params is None else optim_params
optim_params_dict = __extract_optim_params(schema_madlib,
optim_params,
'Ordinal')
return __ordinal_compute(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, category_list, link_func, grouping_col,
optim_params_dict, verbose)
# ========================================================================
def __ordinal_validate_args(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, cat_order, link_func, grouping_col):
__glm_validate_args(schema_madlib, source_table, model_table,
dependent_varname, independent_varname, grouping_col)
if link_func is not None and link_func not in ('logit', 'probit'):
plpy.error("Ordinal error: Invalid link function!\n"
"Only 'logit' and 'probit' are supported.")
category_list = plpy.execute("""
SELECT array_agg(category ORDER BY category) AS category_list
FROM (
SELECT distinct {dependent_varname} AS category
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
) subq
""".format(**locals()))[0]['category_list']
if len(category_list) == 0:
plpy.error("Ordinal error: No non-null categories found!")
if len(category_list) == 1:
plpy.error("Ordinal error: Only a single valid category found!")
if not (isinstance(category_list[0], int) or
isinstance(category_list[0], float) or
isinstance(category_list[0], long) or
isinstance(category_list[0], str)):
        plpy.error("Ordinal error: Given category type is not supported!\n"
                   "Only numeric, character, binary data and enumerated types "
                   "are supported. In particular, if the category type is "
                   "boolean, please use the glm() binomial family instead.")
category_list = [str(c) for c in category_list]
if cat_order is None:
order_list = sorted(category_list)
else:
order_list = [c.strip() for c in cat_order.split('<')]
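        # e.g. cat_order = 'low < medium < high' yields
        # ['low', 'medium', 'high']; each entry must match one of the
        # distinct categories found in the source table.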
        if len(order_list) != len(category_list):
plpy.error("Ordinal error: category order specification is not valid!")
for c in order_list:
if category_list.count(c) == 0:
plpy.error("Ordinal error: '{c}' is not found in source table {source_table}."
.format(**locals()))
category_list = order_list
if grouping_col:
grouped_category_counts = plpy.execute("""
SELECT array_agg(category_count) AS counts
FROM (
SELECT count(distinct {dependent_varname}) AS category_count
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
GROUP BY {grouping_col}
) subq
""".format(**locals()))[0]['counts']
if any(c != len(category_list) for c in grouped_category_counts):
plpy.error("Ordinal error: Categories are not consistent across "
"all groups!")
return category_list
# ========================================================================
def __ordinal_compute(
schema_madlib, tbl_source, tbl_output, col_dep_var, col_ind_var, category_list,
link_func, grouping_col, optim_params_dict, verbose):
old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
if verbose:
plpy.execute("SET client_min_messages TO info")
else:
plpy.execute("SET client_min_messages TO warning")
args = {'schema_madlib': schema_madlib,
'rel_source': tbl_source,
'tbl_output': tbl_output,
'col_dep_var': col_dep_var,
'col_ind_var': col_ind_var,
'rel_state': unique_string(),
'col_grp_iteration': unique_string(),
'col_grp_state': unique_string(),
'state_type': schema_madlib + ".bytea8",
'n_categories': len(category_list),
'link': link_func}
args.update(optim_params_dict)
    # Build the grouping expression: cast each grouping column to text so the
    # group-iteration controller can key per-group states by grouping values.
if grouping_col:
grouping_list = explicit_bool_to_text(
tbl_source, _string_to_array_with_quotes(grouping_col), schema_madlib)
for i in range(len(grouping_list)):
grouping_list[i] += "::text"
grouping_str = ','.join(grouping_list)
else:
grouping_col = None
grouping_str = "NULL"
args['grouping_col'] = grouping_col
args['grouping_str'] = grouping_str
# build the case when expression to convert category value to integer
# when aggregate is called
category_expr_tmp = '\n'.join([
"WHEN {d}::text = '{c}' THEN {i}".format(d=col_dep_var, c=c, i=i)
for i, c in enumerate(category_list)])
args['category_expr'] = "CASE " + category_expr_tmp + "\nEND"
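    # For example, with category_list = ['low', 'medium', 'high'] the
    # generated expression is (illustrative dependent variable 'y'):
    #     CASE WHEN y::text = 'low' THEN 0
    #     WHEN y::text = 'medium' THEN 1
    #     WHEN y::text = 'high' THEN 2
    #     END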
iteration_run = __compute_ordinal(args)
# output table
grouping_str1 = "" if grouping_col is None else grouping_col + ","
grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
join_str = "," if grouping_str1 == "" else "join "
    category_str = ','.join(category_list)
args['category_str'] = category_str
q_out_table = """
CREATE TABLE {tbl_output} AS
SELECT
{grouping_str1}
(result).coef_alpha AS coef_threshold,
(result).std_err_alpha AS std_err_threshold,
(result).z_stats_alpha AS z_stats_threshold,
(result).p_values_alpha AS p_values_threshold,
(result).loglik AS log_likelihood,
(result).coef_beta AS coef_feature,
(result).std_err_beta AS std_err_feature,
(result).z_stats_beta AS z_stats_feature,
(result).p_values_beta AS p_values_feature,
(CASE WHEN result IS NULL THEN 0
ELSE (result).num_rows_processed
END)::bigint AS num_rows_processed,
(CASE WHEN result IS NULL THEN num_rows
ELSE num_rows - (result).num_rows_processed
END)::bigint AS num_rows_skipped,
{col_grp_iteration}::integer AS num_iterations
FROM
(
SELECT
{col_grp_iteration}, {grouping_str1} result, num_rows
FROM
(
(
SELECT
{grouping_str1}
{schema_madlib}.__ordinal_result({col_grp_state}) AS result,
{col_grp_iteration}
FROM
{rel_state}
) t
JOIN
(
SELECT
{grouping_str1}
max({col_grp_iteration}) AS {col_grp_iteration}
FROM {rel_state}
GROUP BY {grouping_str2}
) s
USING ({grouping_str1} {col_grp_iteration})
) q1
{join_str}
(
SELECT
{grouping_str1}
count(*) AS num_rows
FROM {rel_source}
GROUP BY {grouping_str2}
) q2
{using_str}
) q3
""".format(grouping_str1=grouping_str1,
grouping_str2=grouping_str2,
iteration_run=iteration_run,
using_str=using_str,
join_str=join_str,
**args)
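    # q1 pairs each group's final state (the row whose iteration number equals
    # the group's maximum) with that iteration count, q2 counts the source
    # rows per group, and q3 combines them so skipped rows can be reported.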
# plpy.info(q_out_table)
plpy.execute(q_out_table)
# summary table
failed_groups = plpy.execute("""
SELECT count(*) AS num_failed_groups
FROM {tbl_output}
WHERE coef_threshold IS NULL
""".format(**args))[0]
all_groups = plpy.execute("""
SELECT count(*) AS num_all_groups
FROM {tbl_output}
""".format(**args))[0]
total_rows = plpy.execute("""
SELECT
sum(num_rows_processed) AS total_rows_processed,
sum(num_rows_skipped) AS total_rows_skipped
FROM {tbl_output}
""".format(tbl_output=tbl_output))[0]
args.update(failed_groups)
args.update(all_groups)
args.update(total_rows)
tbl_output_summary = add_postfix(tbl_output, "_summary")
plpy.execute("""
CREATE TABLE {tbl_output_summary} AS
SELECT
'ordinal'::varchar AS method,
'{rel_source}'::varchar AS source_table,
'{tbl_output}'::varchar AS out_table,
$msq${col_dep_var}$msq$::varchar AS dependent_varname,
$msq${col_ind_var}$msq$::varchar AS independent_varname,
'{category_str}'::varchar AS category_list,
'{link}'::varchar AS link_func,
{g_str}::varchar AS grouping_col,
'optimizer={optimizer}, ' ||
'max_iter={max_iter}, ' ||
'tolerance={tolerance}'::varchar AS optimizer_params,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{total_rows_processed}::bigint AS total_rows_processed,
{total_rows_skipped}::bigint AS total_rows_skipped
""".format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
tbl_output_summary=tbl_output_summary,
**args))
# clean up
plpy.execute("""
DROP TABLE IF EXISTS pg_temp.{rel_state}
""".format(**args))
plpy.execute("SET client_min_messages TO " + old_msg_level)
return None
# ========================================================================
def ordinal_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the ordinal regression model
    @param message A string, the help message indicator
    Returns:
        A string containing the help message
    """
if not message:
help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Ordinal Linear Model:
Currently logit and probit link functions are supported.
For more details on function usage:
SELECT {schema_madlib}.ordinal('usage')
"""
elif message in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal(
source_table, -- name of input table
model_table, -- name of model table
dependent_varname, -- name of dependent variable
independent_varname, -- names of independent variables
cat_order, -- category order specified by '<'
link_func, -- optional, parameter for link function
grouping_col, -- optional, default NULL, names of columns to group-by
optim_params, -- optional, parameters for optimizer
verbose -- optional, default FALSE, whether to print debug info
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
<...> -- grouping columns
'coef_threshold' double precision[], -- vector of threshold coefficients
'std_err_threshold' double precision[], -- vector of standard errors for threshold coefficients
'z_stats_threshold' double precision[], -- vector of z-statistics for threshold coefficients
'p_values_threshold' double precision[], -- vector of p-values for threshold coefficients
'log_likelihood' double precision, -- log likelihood
'coef_feature' double precision[], -- vector of feature coefficients
'std_err_feature' double precision[], -- vector of standard errors for feature coefficients
'z_stats_feature' double precision[], -- vector of z-statistics for feature coefficients
'p_values_feature' double precision[], -- vector of p-values for feature coefficients
    'num_rows_processed' bigint, -- number of rows processed
    'num_rows_skipped' bigint, -- number of rows skipped
    'num_iterations' integer -- number of iterations run
A summary table named <model_table>_summary is also created, which has:
method varchar, -- modeling method name: 'ordinal'
source_table varchar, -- the data source table name
    out_table varchar, -- the output model table name
dependent_varname varchar, -- the dependent variable
independent_varname varchar, -- the independent variable
category_list varchar, -- ordered categories used for training
link_func varchar, -- link function: logit and probit supported
    grouping_col varchar, -- grouping columns used in the regression
    optimizer_params varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
    num_all_groups integer, -- number of groups
    num_failed_groups integer, -- number of groups whose fitting failed
    total_rows_processed bigint, -- total number of rows processed
    total_rows_skipped bigint -- total number of rows skipped
"""
else:
help_string = "No such option. Use {schema_madlib}.ordinal('help')"
return help_string.format(schema_madlib=schema_madlib)
# ===============================================================================
# Ordinal prediction function
# ===============================================================================
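# The prediction queries below implement the cumulative link model:
#     P(y <= category_k | x) = F(alpha_k - x.beta)
# where alpha_k is the k-th threshold, beta the feature coefficients, and F
# the inverse link (logistic CDF for 'logit', standard normal CDF for
# 'probit'). Individual category probabilities are differences of consecutive
# cumulative probabilities, with the last category taking 1 minus the largest
# cumulative probability. The helper below is a minimal pure-Python sketch of
# that arithmetic, kept only as documentation of what the SQL computes; it is
# not called by the driver.
def _ordinal_probabilities_sketch(x, alpha, beta, link='logit'):
    """Illustrative only: per-category probabilities for a single data row."""
    import math  # local import; this sketch is independent of the driver
    xbeta = sum(b * xi for b, xi in zip(beta, x))
    if link == 'logit':
        cumulative = [1.0 / (1.0 + math.exp(-(a - xbeta))) for a in alpha]
    else:  # 'probit': standard normal CDF via the error function
        cumulative = [0.5 * (1.0 + math.erf((a - xbeta) / math.sqrt(2.0)))
                      for a in alpha]
    cumulative.append(1.0)  # P(y <= last category) is 1 by definition
    return [hi - lo for hi, lo in zip(cumulative, [0.0] + cumulative[:-1])]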
def ordinal_predict(schema_madlib, model_table, predict_table,
predicted_value_tab, predict_type, verbose, **kwargs):
"""
    Compute the predicted value for ordinal regression
@param schema_madlib Name of the MADlib schema, properly escaped/quoted
@param model_table Name of table containing training result from
ordinal()
@param predict_table Name of table containing new data to predict
    @param predicted_value_tab Name of the table to output the predicted values
    @param predict_type Type of prediction: 'response' or 'probability'
@param verbose Whether verbose will be displayed
@param kwargs We allow the caller to specify additional arguments (all of
which will be ignored though). The purpose of this is to allow the
caller to unpack a dictionary whose element set is a superset of
the required arguments by this function.
"""
# Validate the argument
input_tbl_valid(model_table, 'ordinal_predict')
input_tbl_valid(predict_table, 'ordinal_predict')
output_tbl_valid(predicted_value_tab, 'ordinal_predict')
    cols_in_tbl_valid(predict_table, _string_to_array("id"), 'ordinal_predict')
if verbose:
plpy.execute("SET client_min_messages TO info")
else:
plpy.execute("SET client_min_messages TO warning")
if predict_type is None:
predict_type = 'probability'
model_table_summary = add_postfix(model_table, "_summary")
ind_var = plpy.execute("""
SELECT independent_varname FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['independent_varname']
cate_list = plpy.execute("""
SELECT category_list FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['category_list']
cate_list = cate_list.split(',')
cate_list_len = len(cate_list)
cate_list_len_minus_one = cate_list_len - 1
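    # An ordinal model with K categories has only K - 1 threshold
    # coefficients, so the generate_series(...) calls in the queries below
    # run from 1 to cate_list_len - 1.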
group_var = plpy.execute("""
SELECT grouping_col FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['grouping_col']
if group_var is not None:
cols_in_tbl_valid(predict_table, _string_to_array_with_quotes(group_var), 'ordinal_predict')
group_var = group_var.split(',')
if group_var is None:
grp_clause1 = ""
grp_clause2 = ""
grp_clause3 = ""
else:
grp_clause1 = " AND ".join(["{predict_table}.{c} = {model_table}.{c}".format(c=c, predict_table=predict_table, model_table=model_table) for c in group_var])
grp_clause1 = "WHERE " + grp_clause1
grp_clause2 = ", ".join(["{model_table}.{c} as {c}".format(c=c, model_table=model_table) for c in group_var])
grp_clause2 = ", " + grp_clause2
grp_clause3 = " AND ".join(["subq2.{c} = subq3.{c}".format(c=c) for c in group_var])
grp_clause3 = "WHERE " + grp_clause3
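        # grp_clause1 restricts each prediction row to the model row of its
        # own group, grp_clause2 carries the grouping columns through the
        # intermediate subqueries, and grp_clause3 joins subq2 and subq3 on
        # the same grouping columns.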
link_func = plpy.execute("""
SELECT link_func FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['link_func']
if link_func == "logit":
link_clause = "exp(gamma)/(1+exp(gamma))"
elif link_func == "probit":
link_clause = "{schema_madlib}.normal_cdf(gamma)".format(schema_madlib=schema_madlib)
else:
plpy.error("Invalid link function!\n")
if predict_type == 'probability':
        score_format = '\n'.join([
            'indpr[{j}] as "{c}",'.format(j=i + 1, c=c)
            for i, c in enumerate(cate_list)])
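        # For cate_list = ['low', 'medium', 'high'] (illustrative), the
        # generated column list is:
        #     indpr[1] as "low",
        #     indpr[2] as "medium",
        #     indpr[3] as "high",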
plpy.execute("""
CREATE TABLE {predicted_value_tab} AS
SELECT {score_format}
id
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT
generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id)subq6, {model_table_summary}
)subq7
""".format(**locals()))
elif predict_type == 'response':
plpy.execute("""
CREATE TABLE {predicted_value_tab} AS
SELECT
subq8.id as id,
subq8.category as category
FROM
(
SELECT
id,
max(prob) as max_prob
FROM
(
SELECT
id,
unnest(indpr) as prob,
unnest(catearray) as category
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id
)subq6,
{model_table_summary}
)subq7
)subq8
GROUP BY id
)subq9,
(
SELECT
id,
unnest(indpr) as prob,
unnest(catearray) as category
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id
)subq6, {model_table_summary}
)subq7
)subq8
WHERE subq9.id = subq8.id and subq9.max_prob = subq8.prob
""".format(**locals()))
else:
        plpy.error("Invalid prediction type!\n"
                   "Only 'response' and 'probability' are supported.")
return None
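# One possible call of the prediction interface documented below, with
# hypothetical table names; the table being scored must contain an 'id'
# column and the independent variables used during training:
#
#     SELECT madlib.ordinal_predict('patient_model', 'patient_new',
#                                   'patient_pred', 'probability', FALSE);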
# ========== help message for the prediction function ======================
def ordinal_predict_help_msg(schema_madlib, message, **kwargs):
""" Help message for prediction function for ordinal regression
@param message A string, the help message indicator
Returns:
        A string containing the help message
"""
if not message:
help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Prediction function for ordinal linear regression:
Estimates the per-category conditional probabilities, or the most likely
(response) category, for a new set of predictors.
For more details on function usage:
SELECT {schema_madlib}.ordinal_predict('usage')
"""
elif message in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal_predict(
model_table, -- Name of the table containing the output of ordinal()
predict_table_input, -- Name of the table containing new data
output_table, -- Name of the table storing the result of predicted values
    predict_type,        -- Supported types: "response" or "probability"
    verbose              -- Whether verbose output is displayed
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table contains the 'id' column plus the predicted category when
predict_type is 'response', or one probability column per category when
predict_type is 'probability'.
"""
else:
help_string = "No such option. Use {schema_madlib}.ordinal_predict('help')"
return help_string.format(schema_madlib=schema_madlib)