| # coding=utf-8 |
| |
| """ |
| @file multinom.py_in |
| |
| @brief Multinomial regression: Driver functions |
| |
| @namespace glm |
| |
| @brief Generalized Linear Models: Driver functions |
| """ |
| import plpy |
| |
| from utilities.in_mem_group_control import GroupIterationController |
| from utilities.utilities import unique_string |
| from utilities.validate_args import explicit_bool_to_text |
| from utilities.utilities import _string_to_array |
| from utilities.utilities import _string_to_array_with_quotes |
| from utilities.utilities import add_postfix |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import columns_exist_in_table |
| |
| from glm import __glm_validate_args |
| from glm import __extract_optim_params |
| |
| # ======================================================================== |
| |
| |
def __compute_multinom(arg_dict):
    """
    Run the grouped iteration loop of the multinomial regression

    A GroupIterationController is built from arg_dict. Each pass hands the
    link-specific aggregate ``__multinom_{link}_agg`` to the controller's
    update(); the loop ends as soon as the controller's test() expression
    holds, i.e. the iteration count has reached ``max_iter`` or
    ``__multinom_loglik_diff`` of the previous and current state is below
    ``tolerance``. final() is called once before leaving the controller.

    @param arg_dict Dictionary handed to GroupIterationController; the
        templates below refer to schema_madlib, link, category_expr,
        col_ind_var, rel_state, col_grp_state, n_categories, iteration,
        max_iter and tolerance

    @return Number of iterations that has been run
    """
    # aggregate call evaluated on every pass
    update_expr = """
        {schema_madlib}.__multinom_{link}_agg(
            ({category_expr})::integer,
            ({col_ind_var})::double precision[],
            {rel_state}.{col_grp_state},
            {n_categories}::smallint)
        """
    # stopping rule: iteration budget used up or log-likelihood settled
    stop_expr = """
        {iteration} >= {max_iter}
        OR {schema_madlib}.__multinom_loglik_diff(
            _state_previous, _state_current) < {tolerance}
        """

    controller = GroupIterationController(arg_dict)
    with controller as ctrl:
        ctrl.iteration = 0
        finished = False
        while not finished:
            ctrl.update(update_expr)
            finished = ctrl.test(stop_expr)
        ctrl.final()

    return controller.iteration
| |
| # ======================================================================== |
| |
| |
def multinom(schema_madlib, source_table, model_table,
             dependent_varname, independent_varname, ref_category, link_func,
             grouping_col, optim_params, verbose, **kwargs):
    """
    Entry point of multinomial regression training

    Validates the arguments, fills in defaults for the optional ones,
    parses the optimizer parameter string and delegates the fit to
    __multinom_compute().

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the input table
    @param model_table Name of the model table to create
    @param dependent_varname Dependent variable (category) expression
    @param independent_varname Independent variables expression
    @param ref_category Reference category; None selects the first entry
        of the validated category list
    @param link_func Link function; None selects 'logit'
    @param grouping_col Grouping columns, may be None
    @param optim_params Optimizer parameter string; None is treated as ''
    @param verbose Passed through to __multinom_compute()
    @param kwargs Additional arguments, ignored

    @return Whatever __multinom_compute() returns
    """
    categories = __multinom_validate_args(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, link_func, grouping_col)

    # replace the optional arguments that were left unset
    if ref_category is None:
        ref_category = categories[0]
    if link_func is None:
        link_func = 'logit'
    if optim_params is None:
        optim_params = ''

    parsed_optim_params = __extract_optim_params(
        schema_madlib, optim_params, 'Multinom')

    return __multinom_compute(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ref_category, categories, link_func,
        grouping_col, parsed_optim_params, verbose)
| |
| # ======================================================================== |
| |
| |
def __multinom_validate_args(
        schema_madlib, source_table, model_table,
        dependent_varname, independent_varname, ref_category,
        link_func, grouping_col):
    """
    Validate the arguments of multinom() and collect the category values

    Runs the shared GLM argument validation, checks the link function,
    reads the distinct non-null values of the dependent variable (only rows
    whose independent variables are non-null and free of NULL elements are
    considered) and verifies that they can be used as categories.

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the input table
    @param model_table Name of the model table
    @param dependent_varname Dependent variable (category) expression
    @param independent_varname Independent variables expression
    @param ref_category Reference category as a string, or None
    @param link_func Link function name, or None; only 'logit' is accepted
    @param grouping_col Grouping columns, or None/empty for no grouping

    @return List of the category values converted to strings. When
        ref_category is given it is moved to position 0, so that it maps
        to integer 0.

    Problems are reported through plpy.error().
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)

    # ('logit',) has to be a real tuple: with the bare string ('logit') the
    # membership test is a substring test and accepts 'log', 'git' or ''.
    if link_func is not None and link_func not in ('logit',):
        plpy.error("Multinom error: Invalid link function!\n"
                   "Only 'logit' is supported.")

    category_list = plpy.execute("""
        SELECT array_agg(category ORDER BY category) AS category_list
        FROM (
            SELECT distinct {dependent_varname} AS category
            FROM {source_table}
            WHERE {independent_varname} IS NOT NULL
              AND NOT {schema_madlib}.array_contains_null(
                        {independent_varname})
              AND {dependent_varname} IS NOT NULL
        ) subq
        """.format(dependent_varname=dependent_varname,
                   source_table=source_table,
                   independent_varname=independent_varname,
                   schema_madlib=schema_madlib))[0]['category_list']

    # array_agg() over zero rows yields NULL, which arrives here as None
    # rather than as an empty list, so len() must not be called first.
    if not category_list:
        plpy.error("Multinom error: No non-null categories found!")
    if len(category_list) == 1:
        plpy.error("Multinom error: Only a single valid category found!")

    try:
        supported_types = (int, float, long, str)
    except NameError:
        # Python 3 has no separate long type
        supported_types = (int, float, str)

    first_category = category_list[0]
    # bool is a subclass of int, so it has to be rejected explicitly to
    # match what the error message promises.
    if (isinstance(first_category, bool) or
            not isinstance(first_category, supported_types)):
        plpy.error("Multinom error: Given category type is not supported!\n"
                   "Only numeric, character, binary data and enumerated types "
                   "are supported. Particularly, if the category type is boolean, "
                   "please use glm() binomial family instead.")

    category_list = [str(c) for c in category_list]
    if ref_category is not None and ref_category not in category_list:
        plpy.error("Multinom error: Given ref_category is not found! "
                   "'{ref_category}' is not found in source table {source_table}.".
                   format(ref_category=ref_category,
                          source_table=source_table))

    # set the reference category in the first position of category list in
    # order to map the reference category to integer 0
    if ref_category is not None and ref_category != category_list[0]:
        i = category_list.index(ref_category)
        category_list[0], category_list[i] = category_list[i], category_list[0]

    if grouping_col:
        # every group has to contain exactly the categories found overall
        grouped_category_counts = plpy.execute("""
            SELECT array_agg(category_count) AS counts
            FROM (
                SELECT count(distinct {dependent_varname}) AS category_count
                FROM {source_table}
                WHERE {independent_varname} IS NOT NULL
                  AND NOT {schema_madlib}.array_contains_null(
                            {independent_varname})
                  AND {dependent_varname} IS NOT NULL
                GROUP BY {grouping_col}
            ) subq
            """.format(dependent_varname=dependent_varname,
                       source_table=source_table,
                       independent_varname=independent_varname,
                       schema_madlib=schema_madlib,
                       grouping_col=grouping_col))[0]['counts']
        if any(c != len(category_list) for c in grouped_category_counts):
            plpy.error("Multinom error: Categories are not consistent across "
                       "all groups!")

    return category_list
| # ======================================================================== |
| |
| |
def __multinom_compute(schema_madlib, tbl_source, tbl_output, col_dep_var,
                       col_ind_var, ref_category, category_list, link_func,
                       grouping_col, optim_params_dict, verbose):
    """
    Fit the multinomial model and write the model and summary tables

    @param schema_madlib Name of the MADlib schema
    @param tbl_source Name of the source table
    @param tbl_output Name of the model table to create; a second table
        with the postfix "_summary" is created as well
    @param col_dep_var Dependent variable (category) expression
    @param col_ind_var Independent variables expression
    @param ref_category Reference category value, recorded in the summary
    @param category_list Category values as strings; the position of a
        category in this list is the integer code passed to the aggregate
    @param link_func Link function name, selects __multinom_<link>_agg
    @param grouping_col Comma-separated grouping columns, or None/empty
    @param optim_params_dict Optimizer parameters merged into the template
        arguments (the templates read optimizer, max_iter and tolerance)
    @param verbose If true client_min_messages is set to info while the
        function runs, otherwise to warning; the previous level is put
        back at the end

    @return None
    """

    def quote_literal(value):
        # Render value as a SQL string literal. An embedded single quote is
        # doubled so that a value such as "it's" cannot end the literal.
        return "'" + value.replace("'", "''") + "'"

    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    args = {'schema_madlib': schema_madlib,
            'rel_source': tbl_source,
            'tbl_output': tbl_output,
            'rel_state': unique_string(),
            'col_dep_var': col_dep_var,
            'col_ind_var': col_ind_var,
            'col_grp_iteration': unique_string(),
            'col_grp_state': unique_string(),
            'col_n_tuples': unique_string(),
            'ref_category': ref_category,
            'n_categories': len(category_list),
            'link': link_func,
            'state_type': schema_madlib + ".bytea8"
            }
    args.update(optim_params_dict)

    # grouping values are compared as text
    if grouping_col:
        grouping_list = explicit_bool_to_text(
            tbl_source, _string_to_array_with_quotes(grouping_col), schema_madlib)
        grouping_str = ','.join([col + "::text" for col in grouping_list])
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str

    # build the case when expression to convert category value to integer
    # when aggregate is called
    when_clauses = [
        "WHEN ({col_dep_var})::text = {c} THEN {i}".
        format(col_dep_var=col_dep_var, c=quote_literal(c), i=i)
        for i, c in enumerate(category_list)]
    args['category_expr'] = "CASE " + "\n    ".join(when_clauses) + "\nEND"

    # REAL COMPUTATION #
    __compute_multinom(args)

    # output table
    grouping_str1 = "" if grouping_col is None else grouping_col + ","
    grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
    using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
    join_str = "," if grouping_str1 == "" else "join "

    # An ARRAY[...] constructor keeps every category value intact. The
    # '{a,b}' literal form would split values on embedded commas, trim
    # padding and choke on quotes or braces, mislabelling the output rows.
    category_array = ("ARRAY[" +
                      ", ".join([quote_literal(c) for c in category_list]) +
                      "]")
    # The summary keeps the comma-joined form that multinom_predict() reads
    # back with split(','); category values containing a comma therefore
    # remain ambiguous there.
    category_str = ','.join(category_list)

    q_out_table = """
        DROP TABLE IF EXISTS {tbl_output};
        CREATE TABLE {tbl_output} AS
        SELECT
            {grouping_str1}
            category_list[index+1] AS category,
            {schema_madlib}.index_2d_array((result).coef, index) AS coef,
            (result).loglik AS log_likelihood,
            {schema_madlib}.index_2d_array((result).std_err, index) AS std_err,
            {schema_madlib}.index_2d_array((result).z_stats, index) AS z_stats,
            {schema_madlib}.index_2d_array((result).p_values, index) AS p_values,
            (CASE WHEN result IS NULL THEN 0
                  ELSE (result).num_rows_processed
             END)::bigint AS num_rows_processed,
            (CASE WHEN result IS NULL THEN num_rows
                  ELSE num_rows - (result).num_rows_processed
             END)::bigint AS num_rows_skipped,
            {col_grp_iteration}::integer AS num_iterations
        FROM
        (
            SELECT
                {col_grp_iteration}, {grouping_str1} result, num_rows
            FROM
            (
                (
                    SELECT
                        {grouping_str1}
                        {schema_madlib}.__multinom_result(
                            {col_grp_state}) AS result,
                        {col_grp_iteration}
                    FROM
                        {rel_state}
                ) t
                JOIN
                (
                    SELECT
                        {grouping_str1}
                        max({col_grp_iteration}) AS {col_grp_iteration}
                    FROM {rel_state}
                    GROUP BY {grouping_str2}
                ) s
                USING ({grouping_str1} {col_grp_iteration})
            ) q1
            {join_str}
            (
                SELECT
                    {grouping_str1}
                    count(*) AS num_rows
                FROM {rel_source}
                GROUP BY {grouping_str2}
            ) q2
            {using_str}
        ) q3,
        (
            SELECT generate_series(1, {n_categories}-1) AS index
        ) q4,
        (
            SELECT {category_array}::varchar[] AS category_list
        ) q5
        """.format(grouping_str1=grouping_str1,
                   grouping_str2=grouping_str2,
                   using_str=using_str,
                   join_str=join_str,
                   category_array=category_array,
                   **args)
    plpy.execute(q_out_table)

    # summary table
    # NOTE(review): the model table holds one row per group and
    # non-reference category, so the counts and sums below are taken over
    # those rows -- confirm that this is the intended meaning of "groups".
    failed_groups = plpy.execute("""
        SELECT count(*) AS num_failed_groups
        FROM {tbl_output}
        WHERE coef IS NULL
        """.format(**args))[0]
    all_groups = plpy.execute("""
        SELECT count(*) AS num_all_groups
        FROM {tbl_output}
        """.format(**args))[0]
    total_rows = plpy.execute("""
        SELECT
            sum(num_rows_processed) AS total_rows_processed,
            sum(num_rows_skipped) AS total_rows_skipped
        FROM {tbl_output}
        """.format(tbl_output=tbl_output))[0]

    args.update(failed_groups)
    args.update(all_groups)
    args.update(total_rows)

    tbl_output_summary = add_postfix(tbl_output, "_summary")
    plpy.execute("""
        CREATE TABLE {tbl_output_summary} AS
        SELECT
            'multinom'::varchar AS method,
            '{rel_source}'::varchar AS source_table,
            '{tbl_output}'::varchar AS out_table,
            $madlib_super_quote${col_dep_var}$madlib_super_quote$::varchar
                AS dependent_varname,
            $madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
                AS independent_varname,
            {ref_category_literal}::varchar AS ref_category,
            {category_str_literal}::varchar AS category_list,
            '{link}'::varchar AS link_func,
            {g_str}::varchar AS grouping_col,
            'optimizer={optimizer}, ' ||
                'max_iter={max_iter}, ' ||
                'tolerance={tolerance}'::varchar AS optimizer_params,
            {num_all_groups}::integer AS num_all_groups,
            {num_failed_groups}::integer AS num_failed_groups,
            {total_rows_processed}::bigint AS total_rows_processed,
            {total_rows_skipped}::bigint AS total_rows_skipped
        """.format(g_str=quote_literal(grouping_col) if grouping_col else "NULL",
                   tbl_output_summary=tbl_output_summary,
                   ref_category_literal=quote_literal(ref_category),
                   category_str_literal=quote_literal(category_str),
                   **args))

    # clean up
    plpy.execute("""
        DROP TABLE IF EXISTS pg_temp.{rel_state}
        """.format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)

    return None
| # ======================================================================== |
| |
| |
def multinom_help_msg(schema_madlib, message, **kwargs):
    """ Help message for multinomial regression model

    @param schema_madlib Name of the MADlib schema, substituted into the
        example calls of the message
    @param message A string, the help message indicator: empty/None gives
        the summary, 'usage', 'help' or '?' gives the usage text, anything
        else gives a "No such option" hint
    @param kwargs Additional arguments, ignored

    Returns:
        A string, contains the help message
    """
    if not message:

        help_string = """
----------------------------------------------------------------
                        SUMMARY
----------------------------------------------------------------
Multinomial Linear Model:

Currently only logit link functions are supported.

For more details on function usage:
    SELECT {schema_madlib}.multinom('usage')
        """
    elif message in ['usage', 'help', '?']:

        # The column names below mirror the tables this module creates; in
        # particular the summary column holding the model table name is
        # called out_table.
        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom(
    source_table,          -- name of input table
    model_table,           -- name of model table
    dependent_varname,     -- name of dependent variable
    independent_varname,   -- names of independent variables
    ref_category,          -- optional, parameter for reference category
    link_func,             -- optional, parameter for link function
    grouping_col,          -- optional, default NULL, names of columns to group-by
    optim_params,          -- optional, parameters for optimizer
    verbose                -- optional, default FALSE, whether to print debug info
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
    <...>,                                    -- grouping columns
    'category'            varchar,            -- category value
    'coef'                double precision[], -- vector of coefficients
    'log_likelihood'      double precision,   -- log likelihood
    'std_err'             double precision[], -- vector of standard errors
    'z_stats'             double precision[], -- vector of z-statistics
    'p_values'            double precision[], -- vector of p-values
    'num_rows_processed'  bigint,             -- numbers of rows processed
    'num_rows_skipped'    bigint,             -- numbers of rows skipped
    'num_iterations'      integer             -- number of iterations run

A summary table named <model_table>_summary is also created at the same time, which has:
    method                varchar, -- modeling method name: 'multinom'
    source_table          varchar, -- the data source table name
    out_table             varchar, -- the output table name
    dependent_varname     varchar, -- the dependent variable
    independent_varname   varchar, -- the independent variable
    ref_category          varchar, -- reference category value
    category_list         varchar, -- all categories used for training
    link_func             varchar, -- link function
    grouping_col          varchar, -- grouping columns used in the regression
    optimizer_params      varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
    num_all_groups        integer, -- how many groups
    num_failed_groups     integer, -- how many groups' fitting processes failed
    total_rows_processed  bigint,  -- total numbers of rows processed
    total_rows_skipped    bigint   -- total numbers of rows skipped
        """
    else:
        help_string = "No such option. Use {schema_madlib}.multinom('help')"

    return help_string.format(schema_madlib=schema_madlib)
| |
| # =============================================================================== |
| # Multinomial prediction function |
| # =============================================================================== |
| |
| |
def multinom_predict(schema_madlib, model_table, predict_table,
                     predicted_value_tab, predict_type, verbose,
                     id_column, **kwargs):
    """
    Compute the predicted value for multinomial regression

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param model_table Name of table containing training result from
        multinom(); its "_summary" companion table is read as well
    @param predict_table Name of table containing new data to predict
    @param predicted_value_tab Name of table to output the predict value
    @param predict_type Type of predict value: 'response' (used when None
        is given), 'probability' or its short form 'prob'
    @param verbose If true client_min_messages is set to info while the
        function runs, otherwise to warning; the previous level is put
        back at the end
    @param id_column Name of ID column in the input table
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.

    @return None

    Problems are reported through plpy.error().
    """

    def quote_literal(value):
        # Render value as a SQL string literal with embedded single quotes
        # doubled.
        return "'" + value.replace("'", "''") + "'"

    # Validate the argument
    input_tbl_valid(model_table, 'multinom_predict')
    input_tbl_valid(predict_table, 'multinom_predict')
    output_tbl_valid(predicted_value_tab, 'multinom_predict')

    if predict_type is None:
        predict_type = 'response'
    # reject an unknown type before any query is issued
    if predict_type not in ('response', 'probability', 'prob'):
        plpy.error("Invalid prediction type!\n")

    # remember the message level so it can be put back afterwards, the same
    # way __multinom_compute() does
    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    # everything needed from the summary table, fetched in one round trip
    model_table_summary = add_postfix(model_table, "_summary")
    summary = plpy.execute("""
        SELECT ref_category, independent_varname, category_list, grouping_col
        FROM {model_table_summary}
        """.format(model_table_summary=model_table_summary))[0]
    ref_category = summary['ref_category']
    ind_var = summary['independent_varname']
    cate_list = summary['category_list'].split(',')
    group_var = summary['grouping_col']

    if group_var is None:
        grp_clause = ""
    else:
        cols_in_tbl_valid(predict_table, _string_to_array(group_var),
                          'multinom_predict')
        # a row is only scored with the coefficients of its own group
        grp_clause = "WHERE " + " AND ".join([
            "{predict_table}.{c} = {model_table}.{c}".
            format(c=c, predict_table=predict_table, model_table=model_table)
            for c in group_var.split(',')])

    if columns_exist_in_table(predict_table, [id_column], schema_madlib):
        multinom_predict_id = id_column
    else:
        multinom_predict_id = 'multinom_predict_id'

    sql_args = {'schema_madlib': schema_madlib,
                'model_table': model_table,
                'predict_table': predict_table,
                'predicted_value_tab': predicted_value_tab,
                'multinom_predict_id': multinom_predict_id,
                'id': id_column,
                'ind_var': ind_var,
                'grp_clause': grp_clause,
                'ref_category_literal': quote_literal(ref_category)}

    if predict_type == 'response':
        # The reference category has score 0; the predicted category is the
        # one whose linear score equals the maximum over all categories.
        sql = """
            CREATE TABLE {predicted_value_tab} AS
            SELECT
                subq2.{multinom_predict_id},
                subq3.category AS category
            FROM
            (
                SELECT
                    greatest(0, max_score) AS max_score,
                    {multinom_predict_id}
                FROM
                (
                    SELECT
                        max(
                            {schema_madlib}.array_dot(coef, {ind_var}::float8[])
                        ) AS max_score,
                        {id} AS {multinom_predict_id}
                    FROM
                        {predict_table},
                        {model_table}
                    {grp_clause}
                    GROUP BY {id}
                ) subq
            ) subq2
            LEFT JOIN
            (
                SELECT
                    {schema_madlib}.array_dot(coef, {ind_var}::float8[]) AS score,
                    {id} AS {multinom_predict_id},
                    category::TEXT
                FROM
                    {predict_table},
                    {model_table}
                {grp_clause}
                UNION
                SELECT
                    0 AS score,
                    {id} AS {multinom_predict_id},
                    {ref_category_literal} AS category
                FROM {predict_table}
            ) subq3
            ON
            (
                subq2.{multinom_predict_id} = subq3.{multinom_predict_id}
                AND
                subq2.max_score=subq3.score
            )
            ORDER BY subq2.{multinom_predict_id};
            """.format(**sql_args)
    else:
        # 'probability' / 'prob': one output column per category, in the
        # order stored in the summary table
        score_format = '\n'.join([
            ",score_arr[{j}] as \"{c}\"".
            format(j=i + 1, c=c.replace('"', '""'))
            for i, c in enumerate(cate_list)])

        score_map = '\n'.join([
            "WHEN category={c} THEN {j}".
            format(j=i + 1, c=quote_literal(c))
            for i, c in enumerate(cate_list)])

        # built explicitly instead of relying on the repr() of a Python
        # list, which is not valid SQL once a value contains a quote
        cate_array = ("ARRAY[" +
                      ", ".join([quote_literal(c) for c in cate_list]) +
                      "]")

        sql = """
            CREATE TABLE {predicted_value_tab} AS
            SELECT
                {multinom_predict_id}
                {score_format}
            FROM
            (
                SELECT
                    {schema_madlib}.array_scalar_mult(
                        array_agg(score ORDER BY idx),
                        1. / {schema_madlib}.array_sum(
                            array_agg(score)::float8[]
                        )
                    ) AS score_arr,
                    array_agg(category ORDER BY idx) AS cate_arr,
                    {multinom_predict_id}
                FROM
                (
                    SELECT
                        score,
                        {multinom_predict_id},
                        subq2.category,
                        idx
                    FROM
                    (
                        SELECT
                            exp({schema_madlib}.array_dot(
                                    coef,
                                    {ind_var}::float8[]
                                )
                            ) AS score,
                            category,
                            {id} AS {multinom_predict_id}
                        FROM
                            {predict_table},
                            {model_table}
                        {grp_clause}
                        UNION
                        SELECT
                            1. AS score,
                            {ref_category_literal} AS category,
                            {id} AS {multinom_predict_id}
                        FROM {predict_table}
                    ) subq2
                    LEFT JOIN
                    (
                        SELECT
                            category,
                            CASE {score_map} END as idx
                        FROM
                        (
                            SELECT unnest({cate_array}) AS category
                        ) subq
                    ) subq3
                    ON (subq2.category = subq3.category)
                    ORDER BY {multinom_predict_id}, idx
                ) subq4
                GROUP BY {multinom_predict_id}
            ) subq5
            ORDER by {multinom_predict_id};
            """.format(score_format=score_format,
                       score_map=score_map,
                       cate_array=cate_array,
                       **sql_args)

    plpy.notice(sql)
    plpy.execute(sql)

    # do not leave the session with the temporarily changed message level
    plpy.execute("SET client_min_messages TO " + old_msg_level)

    return None
| |
| |
| # ========== help message for the prediction function ====================== |
def multinom_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for prediction function for multinomial regression

    @param schema_madlib Name of the MADlib schema, substituted into the
        example calls of the message
    @param message A string, the help message indicator: empty/None gives
        the summary, 'usage', 'help' or '?' gives the usage text, anything
        else gives a "No such option" hint
    @param kwargs Additional arguments, ignored

    Returns:
        A string, contains the help message
    """
    if not message:

        help_string = """
----------------------------------------------------------------
                        SUMMARY
----------------------------------------------------------------
Prediction function for multinomial linear regression:

Estimate the conditional probability or give the response category given
a new set of predictors.

For more details on function usage:
    SELECT {schema_madlib}.multinom_predict('usage')
        """
    elif message in ['usage', 'help', '?']:

        # The predict_type values are spelled exactly as multinom_predict()
        # accepts them, so they can be copied into a call.
        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.multinom_predict(
    model_table,          -- Name of the table containing the output of multinom()
    predict_table_input,  -- Name of the table containing new data
    output_table,         -- Name of the table storing the result of predicted values
    predict_type,         -- Support two types: "response" or "probability"
    verbose,              -- Whether verbose is displayed, default is FALSE
    id_column             -- Name of the id column in the input table
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the predicted category when predict_type
is response and probability when predict_type is probability.
        """
    else:
        help_string = "No such option. Use {schema_madlib}.multinom_predict('help')"

    return help_string.format(schema_madlib=schema_madlib)