# coding=utf-8
"""
@file ordinal.py_in
@brief Ordinal regression: Driver functions
@namespace glm
@brief Ordinal Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from glm import __glm_validate_args
from glm import __extract_optim_params
# ========================================================================
def __compute_ordinal(arg_dict):
"""
Compute Ordinal Model coefficients
This method serves as an interface to different optimization algorithms.
By default, iteratively reweighted least squares is used.
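    The iteration stops once max_iter is reached or once the change in
    log-likelihood between successive iterations, as computed by the
    __ordinal_loglik_diff function, falls below tolerance.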
    @return Number of iterations that have been run
"""
iterationCtrl = GroupIterationController(arg_dict)
with iterationCtrl as it:
it.iteration = 0
while True:
it.update(
"""
{schema_madlib}.__ordinal_{link}_agg(
({category_expr})::integer,
({col_ind_var})::double precision[],
{rel_state}.{col_grp_state},
{n_categories}::smallint)
""")
if it.test(
"""
{iteration} >= {max_iter}
OR {schema_madlib}.__ordinal_loglik_diff(
_state_previous, _state_current) < {tolerance}
"""):
it.final()
break
return iterationCtrl.iteration
# ========================================================================
def ordinal(schema_madlib, source_table, model_table,
dependent_varname, independent_varname, cat_order, link_func,
grouping_col, optim_params, verbose, **kwargs):
category_list = __ordinal_validate_args(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, cat_order, link_func, grouping_col)
# default values
link_func = 'logit' if link_func is None else link_func
optim_params = '' if optim_params is None else optim_params
optim_params_dict = __extract_optim_params(schema_madlib,
optim_params,
'Ordinal')
return __ordinal_compute(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, category_list, link_func, grouping_col,
optim_params_dict, verbose)
# ========================================================================
def __ordinal_validate_args(
schema_madlib, source_table, model_table, dependent_varname,
independent_varname, cat_order, link_func, grouping_col):
__glm_validate_args(schema_madlib, source_table, model_table,
dependent_varname, independent_varname, grouping_col)
if link_func is not None and link_func not in ('logit', 'probit'):
plpy.error("Ordinal error: Invalid link function!\n"
"Only 'logit' and 'probit' are supported.")
category_list = plpy.execute("""
SELECT array_agg(category ORDER BY category) AS category_list
FROM (
SELECT distinct {dependent_varname} AS category
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
) subq
""".format(**locals()))[0]['category_list']
if len(category_list) == 0:
plpy.error("Ordinal error: No non-null categories found!")
if len(category_list) == 1:
plpy.error("Ordinal error: Only a single valid category found!")
if not (isinstance(category_list[0], int) or
isinstance(category_list[0], float) or
isinstance(category_list[0], long) or
isinstance(category_list[0], str)):
plpy.error("Ordinal error: Given category type is not supported!\n"
"Only numeric, character, binary data and enumerated types "
"are supported. Particularly, if the category type is boolean,"
"please use glm() binomial family instead.")
category_list = [str(c) for c in category_list]
if cat_order is None:
order_list = sorted(category_list)
else:
order_list = [c.strip() for c in cat_order.split('<')]
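        # e.g. cat_order = '0<1<2' yields order_list = ['0', '1', '2']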
if len(order_list)!=len(category_list):
plpy.error("Ordinal error: category order specification is not valid!")
for c in order_list:
if category_list.count(c) == 0:
plpy.error("Ordinal error: '{c}' is not found in source table {source_table}."
.format(**locals()))
category_list = order_list
if grouping_col:
grouped_category_counts = plpy.execute("""
SELECT array_agg(category_count) AS counts
FROM (
SELECT count(distinct {dependent_varname}) AS category_count
FROM {source_table}
WHERE {independent_varname} IS NOT NULL
AND NOT {schema_madlib}.array_contains_null(
{independent_varname})
AND {dependent_varname} IS NOT NULL
GROUP BY {grouping_col}
) subq
""".format(**locals()))[0]['counts']
if any(c != len(category_list) for c in grouped_category_counts):
plpy.error("Ordinal error: Categories are not consistent across "
"all groups!")
return category_list
# ========================================================================
def __ordinal_compute(
schema_madlib, tbl_source, tbl_output, col_dep_var, col_ind_var, category_list,
link_func, grouping_col, optim_params_dict, verbose):
old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
if verbose:
plpy.execute("SET client_min_messages TO info")
else:
plpy.execute("SET client_min_messages TO warning")
args = {'schema_madlib': schema_madlib,
'rel_source': tbl_source,
'tbl_output': tbl_output,
'col_dep_var': col_dep_var,
'col_ind_var': col_ind_var,
'rel_state': unique_string(),
'col_grp_iteration': unique_string(),
'col_grp_state': unique_string(),
'state_type': schema_madlib + ".bytea8",
'n_categories': len(category_list),
'link': link_func}
args.update(optim_params_dict)
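    # optim_params_dict supplies 'optimizer', 'max_iter' and 'tolerance';
    # max_iter and tolerance drive the stopping test in __compute_ordinal,
    # and all three are echoed into the summary table below.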
    # Prepare the grouping expressions: each grouping column is cast to text
    # so that the iteration controller can track state and iteration number
    # per grouping value.
if grouping_col:
grouping_list = explicit_bool_to_text(
tbl_source, _string_to_array_with_quotes(grouping_col), schema_madlib)
for i in range(len(grouping_list)):
grouping_list[i] += "::text"
grouping_str = ','.join(grouping_list)
else:
grouping_col = None
grouping_str = "NULL"
args['grouping_col'] = grouping_col
args['grouping_str'] = grouping_str
# build the case when expression to convert category value to integer
# when aggregate is called
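    # For illustration, with col_dep_var = 'cat' and category_list
    # ['0', '1', '2'] the generated expression is:
    #   CASE WHEN cat::text = '0' THEN 0
    #   WHEN cat::text = '1' THEN 1
    #   WHEN cat::text = '2' THEN 2
    #   END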
category_expr_tmp = '\n'.join([
"WHEN {d}::text = '{c}' THEN {i}".format(d=col_dep_var, c=c, i=i)
for i, c in enumerate(category_list)])
args['category_expr'] = "CASE " + category_expr_tmp + "\nEND"
iteration_run = __compute_ordinal(args)
# output table
grouping_str1 = "" if grouping_col is None else grouping_col + ","
grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
join_str = "," if grouping_str1 == "" else "join "
    category_str = ','.join(category_list)
args['category_str'] = category_str
q_out_table = """
DROP TABLE IF EXISTS {tbl_output};
CREATE TABLE {tbl_output} AS
SELECT
{grouping_str1}
(result).coef_alpha AS coef_threshold,
(result).std_err_alpha AS std_err_threshold,
(result).z_stats_alpha AS z_stats_threshold,
(result).p_values_alpha AS p_values_threshold,
(result).loglik AS log_likelihood,
(result).coef_beta AS coef_feature,
(result).std_err_beta AS std_err_feature,
(result).z_stats_beta AS z_stats_feature,
(result).p_values_beta AS p_values_feature,
(CASE WHEN result IS NULL THEN 0
ELSE (result).num_rows_processed
END)::bigint AS num_rows_processed,
(CASE WHEN result IS NULL THEN num_rows
ELSE num_rows - (result).num_rows_processed
END)::bigint AS num_rows_skipped,
{col_grp_iteration}::integer AS num_iterations
FROM
(
SELECT
{col_grp_iteration}, {grouping_str1} result, num_rows
FROM
(
(
SELECT
{grouping_str1}
{schema_madlib}.__ordinal_result({col_grp_state}) AS result,
{col_grp_iteration}
FROM
{rel_state}
) t
JOIN
(
SELECT
{grouping_str1}
max({col_grp_iteration}) AS {col_grp_iteration}
FROM {rel_state}
GROUP BY {grouping_str2}
) s
USING ({grouping_str1} {col_grp_iteration})
) q1
{join_str}
(
SELECT
{grouping_str1}
count(*) AS num_rows
FROM {rel_source}
GROUP BY {grouping_str2}
) q2
{using_str}
) q3
""".format(grouping_str1=grouping_str1,
grouping_str2=grouping_str2,
iteration_run=iteration_run,
using_str=using_str,
join_str=join_str,
**args)
# plpy.info(q_out_table)
plpy.execute(q_out_table)
# summary table
failed_groups = plpy.execute("""
SELECT count(*) AS num_failed_groups
FROM {tbl_output}
WHERE coef_threshold IS NULL
""".format(**args))[0]
all_groups = plpy.execute("""
SELECT count(*) AS num_all_groups
FROM {tbl_output}
""".format(**args))[0]
total_rows = plpy.execute("""
SELECT
sum(num_rows_processed) AS total_rows_processed,
sum(num_rows_skipped) AS total_rows_skipped
FROM {tbl_output}
""".format(tbl_output=tbl_output))[0]
args.update(failed_groups)
args.update(all_groups)
args.update(total_rows)
tbl_output_summary = add_postfix(tbl_output, "_summary")
plpy.execute("""
CREATE TABLE {tbl_output_summary} AS
SELECT
'ordinal'::varchar AS method,
'{rel_source}'::varchar AS source_table,
'{tbl_output}'::varchar AS out_table,
$msq${col_dep_var}$msq$::varchar AS dependent_varname,
$msq${col_ind_var}$msq$::varchar AS independent_varname,
'{category_str}'::varchar AS category_list,
'{link}'::varchar AS link_func,
{g_str}::varchar AS grouping_col,
'optimizer={optimizer}, ' ||
'max_iter={max_iter}, ' ||
'tolerance={tolerance}'::varchar AS optimizer_params,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups,
{total_rows_processed}::bigint AS total_rows_processed,
{total_rows_skipped}::bigint AS total_rows_skipped
""".format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
tbl_output_summary=tbl_output_summary,
**args))
# clean up
plpy.execute("""
DROP TABLE IF EXISTS pg_temp.{rel_state}
""".format(**args))
plpy.execute("SET client_min_messages TO " + old_msg_level)
return None
# ========================================================================
def ordinal_help_msg(schema_madlib, message, **kwargs):
""" Help message for ordinalial linear regression model
@param message A string, the help message indicator
Returns:
A string, contains the help message
"""
if not message:
help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Ordinal Linear Model:
Currently logit and probit link functions are supported.
For more details on function usage:
SELECT {schema_madlib}.ordinal('usage')
For a small example on using the function:
SELECT {schema_madlib}.ordinal('example')
"""
elif message in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal(
source_table, -- name of input table
model_table, -- name of model table
dependent_varname, -- name of dependent variable
independent_varname, -- names of independent variables
cat_order, -- category order specified by '<'
link_func, -- optional, default 'logit', link function ('logit' or 'probit')
grouping_col, -- optional, default NULL, names of columns to group-by
optim_params, -- optional, parameters for optimizer
verbose -- optional, default FALSE, whether to print debug info
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
<...> -- grouping columns
'coef_threshold' double precision[], -- vector of threshold coefficients
'std_err_threshold' double precision[], -- vector of standard errors for threshold coefficients
'z_stats_threshold' double precision[], -- vector of z-statistics for threshold coefficients
'p_values_threshold' double precision[], -- vector of p-values for threshold coefficients
'log_likelihood' double precision, -- log likelihood
'coef_feature' double precision[], -- vector of feature coefficients
'std_err_feature' double precision[], -- vector of standard errors for feature coefficients
'z_stats_feature' double precision[], -- vector of z-statistics for feature coefficients
'p_values_feature' double precision[], -- vector of p-values for feature coefficients
'num_rows_processed' bigint, -- number of rows processed
'num_rows_skipped' bigint, -- number of rows skipped
'num_iterations' integer -- number of iterations run
A summary table named <model_table>_summary is also created at the same time, which has:
method varchar, -- modeling method name: 'ordinal'
source_table varchar, -- the data source table name
out_table varchar, -- the output model table name
dependent_varname varchar, -- the dependent variable
independent_varname varchar, -- the independent variable
category_list varchar, -- ordered categories used for training
link_func varchar, -- link function: logit and probit supported
grouping_col varchar, -- grouping columns used in the regression
optimizer_params varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
num_all_groups integer, -- how many groups
num_failed_groups integer, -- how many groups' fitting processes failed
total_rows_processed bigint, -- total number of rows processed
total_rows_skipped bigint, -- total number of rows skipped
"""
elif message in ['example', 'examples']:
help_string = """
DROP TABLE IF EXISTS test3;
CREATE TABLE test3 (
feat1 INTEGER,
feat2 INTEGER,
cat INTEGER
);
INSERT INTO test3(feat1, feat2, cat) VALUES
(1,35,1),
(2,33,0),
(3,39,1),
(1,37,1),
(2,31,1),
(3,36,0),
(2,36,1),
(2,31,1),
(2,41,1),
(2,37,1),
(1,44,1),
(3,33,2),
(1,31,1),
(2,44,1),
(1,35,1),
(1,44,0),
(1,46,0),
(2,46,1),
(2,46,2),
(3,49,1),
(2,39,0),
(2,44,1),
(1,47,1),
(1,44,1),
(1,37,2),
(3,38,2),
(1,49,0),
(2,44,0),
(3,61,2),
(1,65,2),
(3,67,1),
(3,65,2),
(1,65,2),
(2,67,2),
(1,65,2),
(1,62,2),
(3,52,2),
(3,63,2),
(2,59,2),
(3,65,2),
(2,59,0),
(3,67,2),
(3,67,2),
(3,60,2),
(3,67,2),
(3,62,2),
(2,54,2),
(3,65,2),
(3,62,2),
(2,59,2),
(3,60,2),
(3,63,2),
(3,65,2),
(2,63,1),
(2,67,2),
(2,65,2),
(2,62,2);
-- Run the ordinal logistic regression function.
DROP TABLE IF EXISTS test3_output;
DROP TABLE IF EXISTS test3_output_summary;
SELECT madlib.ordinal('test3',
'test3_output',
'cat',
'ARRAY[feat1, feat2]',
'0<1<2',
'logit'
);
SELECT * FROM test3_output;
"""
else:
help_string = "No such option. Use {schema_madlib}.ordinal('help')"
return help_string.format(schema_madlib=schema_madlib)
# ===============================================================================
# Ordinal prediction function
# ===============================================================================
def ordinal_predict(schema_madlib, model_table, predict_table,
predicted_value_tab, predict_type, verbose, **kwargs):
"""
    Compute the predicted values for ordinal regression
@param schema_madlib Name of the MADlib schema, properly escaped/quoted
@param model_table Name of table containing training result from
ordinal()
@param predict_table Name of table containing new data to predict
    @param predicted_value_tab Name of the table to store the predicted values
    @param predict_type Type of prediction: 'response' or 'probability'
    @param verbose Whether verbose messages will be displayed
@param kwargs We allow the caller to specify additional arguments (all of
which will be ignored though). The purpose of this is to allow the
caller to unpack a dictionary whose element set is a superset of
the required arguments by this function.
"""
# Validate the argument
input_tbl_valid(model_table, 'ordinal_predict')
input_tbl_valid(predict_table, 'ordinal_predict')
output_tbl_valid(predicted_value_tab, 'ordinal_predict')
    cols_in_tbl_valid(predict_table, _string_to_array("id"), 'ordinal_predict')
if verbose:
plpy.execute("SET client_min_messages TO info")
else:
plpy.execute("SET client_min_messages TO warning")
if predict_type is None:
predict_type = 'probability'
model_table_summary = add_postfix(model_table, "_summary")
ind_var = plpy.execute("""
SELECT independent_varname FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['independent_varname']
cate_list = plpy.execute("""
SELECT category_list FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['category_list']
cate_list = cate_list.split(',')
cate_list_len = len(cate_list)
cate_list_len_minus_one = cate_list_len - 1
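    # With K categories the model has K - 1 threshold coefficients
    # (coef_threshold), so the queries below generate_series over 1 .. K - 1
    # and treat the K-th cumulative probability as 1.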
group_var = plpy.execute("""
SELECT grouping_col FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['grouping_col']
if group_var is not None:
cols_in_tbl_valid(predict_table, _string_to_array_with_quotes(group_var), 'ordinal_predict')
group_var = group_var.split(',')
if group_var is None:
grp_clause1 = ""
grp_clause2 = ""
grp_clause3 = ""
else:
grp_clause1 = " AND ".join(["{predict_table}.{c} = {model_table}.{c}".format(c=c, predict_table=predict_table, model_table=model_table) for c in group_var])
grp_clause1 = "WHERE " + grp_clause1
grp_clause2 = ", ".join(["{model_table}.{c} as {c}".format(c=c, model_table=model_table) for c in group_var])
grp_clause2 = ", " + grp_clause2
grp_clause3 = " AND ".join(["subq2.{c} = subq3.{c}".format(c=c) for c in group_var])
grp_clause3 = "WHERE " + grp_clause3
link_func = plpy.execute("""
SELECT link_func FROM {model_table_summary}
""".format(model_table_summary=model_table_summary))[0]['link_func']
if link_func == "logit":
link_clause = "exp(gamma)/(1+exp(gamma))"
elif link_func == "probit":
link_clause = "{schema_madlib}.normal_cdf(gamma)".format(schema_madlib=schema_madlib)
else:
plpy.error("Invalid link function!\n")
if predict_type == 'probability':
score_format = '\n'.join([
"indpr[{j}] as \"{c}\",".
format(j=i+1, c=c)
for i, c in enumerate(cate_list)])
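        # score_format expands to one output column per category, e.g. for
        # categories ['0', '1', '2']:
        #   indpr[1] as "0",
        #   indpr[2] as "1",
        #   indpr[3] as "2",
        # where indpr holds the per-category probabilities obtained by
        # differencing the cumulative probabilities:
        #   indpr = [cumpr_1, cumpr_2 - cumpr_1, ..., 1 - cumpr_{K-1}]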
plpy.execute("""
CREATE TABLE {predicted_value_tab} AS
SELECT {score_format}
id
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT
generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id)subq6, {model_table_summary}
)subq7
""".format(**locals()))
elif predict_type == 'response':
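        # The 'response' prediction recomputes the per-category probabilities
        # (same derivation as above) and keeps, for each id, the category whose
        # probability equals the per-id maximum computed in subq9.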
plpy.execute("""
CREATE TABLE {predicted_value_tab} AS
SELECT
subq8.id as id,
subq8.category as category
FROM
(
SELECT
id,
max(prob) as max_prob
FROM
(
SELECT
id,
unnest(indpr) as prob,
unnest(catearray) as category
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id
)subq6,
{model_table_summary}
)subq7
)subq8
GROUP BY id
)subq9,
(
SELECT
id,
unnest(indpr) as prob,
unnest(catearray) as category
FROM
(
SELECT
id,
{schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
array_cat(ARRAY[0]::float8[],prarray)) as indpr,
array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
FROM
(
SELECT
id,
array_agg(cumpr ORDER BY idx) as prarray,
array_agg(category ORDER BY idx) as catearray
FROM
(
SELECT
id,
{link_clause} as cumpr,
category,
idx
FROM
(
SELECT
id,
(alpha-xbeta) as gamma,
category,
idx
FROM
(
SELECT
id,
({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
{grp_clause2}
FROM
{predict_table},
{model_table}
{grp_clause1}
)subq2,
(
SELECT
i as idx,
coef_threshold[i] as alpha,
(string_to_array(category_list,','))[i] as category
{grp_clause2}
FROM
{model_table},
{model_table_summary},
(
SELECT generate_series(1,{cate_list_len_minus_one}) as i
)subq1
)subq3
{grp_clause3}
)subq4
)subq5
GROUP by id
)subq6, {model_table_summary}
)subq7
)subq8
WHERE subq9.id = subq8.id and subq9.max_prob = subq8.prob
""".format(**locals()))
else:
plpy.error("Invalid prediction type!\n")
return None
# ========== help message for the prediction function ======================
def ordinal_predict_help_msg(schema_madlib, message, **kwargs):
""" Help message for prediction function for ordinal regression
@param message A string, the help message indicator
Returns:
A string, contains the help message
"""
if not message:
help_string = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Prediction function for ordinal linear regression:
Estimate the conditional probability of each category, or the most likely
(response) category, for a new set of predictors.
For more details on function usage:
SELECT {schema_madlib}.ordinal_predict('usage')
For a small example on using the function:
SELECT {schema_madlib}.ordinal_predict('example')
"""
elif message in ['usage', 'help', '?']:
help_string = """
------------------------------------------------------------------
USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal_predict(
model_table, -- Name of the table containing the output of ordinal()
predict_table_input, -- Name of the table containing the new data to predict on
output_table, -- Name of the table to store the predicted values
predict_type, -- Prediction type: 'response' or 'probability'
verbose -- Whether verbose output is displayed
);
------------------------------------------------------------------
OUTPUT
------------------------------------------------------------------
The output table contains an 'id' column plus the predicted category when
predict_type is 'response', or one probability column per category when
predict_type is 'probability'.
"""
elif message in ['example', 'examples']:
help_string = """
-- run the training example first
ALTER TABLE test3 ADD COLUMN id SERIAL;
DROP TABLE IF EXISTS test3_predict;
SELECT madlib.ordinal_predict('test3_output', 'test3', 'test3_predict', 'probability');
SELECT * FROM test3_predict;
"""
else:
help_string = "No such option. Use {schema_madlib}.ordinal_predict('help')"
return help_string.format(schema_madlib=schema_madlib)