| # coding=utf-8 |
| |
| """ |
| @file ordinal.py_in |
| |
| @brief Ordinal regression: Driver functions |
| |
| @namespace glm |
| |
| @brief Ordinal Linear Models: Driver functions |
| """ |
| import plpy |
| |
| from utilities.in_mem_group_control import GroupIterationController |
| from utilities.utilities import unique_string |
| from utilities.validate_args import explicit_bool_to_text |
| from utilities.utilities import _string_to_array |
| from utilities.utilities import _string_to_array_with_quotes |
| from utilities.utilities import add_postfix |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| |
| from glm import __glm_validate_args |
| from glm import __extract_optim_params |
| |
| # ======================================================================== |
| |
| |
def __compute_ordinal(arg_dict):
    """
    Run the iterative fit of the ordinal model for every group.

    The SQL fragments below still contain their `{...}` placeholders when
    they are handed to the iteration controller; presumably
    GroupIterationController fills them in from arg_dict (confirm against
    utilities.in_mem_group_control).

    @param arg_dict Dictionary used to construct the GroupIterationController

    @return Number of iterations that has been run
    """
    # One optimisation step: the link-specific aggregate folds every row
    # (category index, feature vector) into the previous state.
    step_expr = """
        {schema_madlib}.__ordinal_{link}_agg(
            ({category_expr})::integer,
            ({col_ind_var})::double precision[],
            {rel_state}.{col_grp_state},
            {n_categories}::smallint)
        """

    # Stop when the iteration budget is exhausted or the log-likelihood
    # difference between two consecutive states falls below the tolerance.
    stop_expr = """
        {iteration} >= {max_iter}
        OR {schema_madlib}.__ordinal_loglik_diff(
            _state_previous, _state_current) < {tolerance}
        """

    controller = GroupIterationController(arg_dict)

    with controller as run:
        run.iteration = 0
        finished = False
        while not finished:
            run.update(step_expr)
            finished = run.test(stop_expr)
        run.final()

    return controller.iteration
| |
| # ======================================================================== |
| |
| |
def ordinal(schema_madlib, source_table, model_table,
            dependent_varname, independent_varname, cat_order, link_func,
            grouping_col, optim_params, verbose, **kwargs):
    """
    Entry point for training: validate the arguments, apply defaults and fit.

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the table holding the training data
    @param model_table Name of the table the model is written to
    @param dependent_varname Expression for the (ordinal) response
    @param independent_varname Expression for the feature array
    @param cat_order Category order, categories separated by '<', or None
    @param link_func 'logit' or 'probit'; None selects 'logit'
    @param grouping_col Comma-separated grouping columns, or None
    @param optim_params Optimizer parameter string; None is treated as ''
    @param verbose Whether informational messages are shown
    @param kwargs Ignored

    @return Whatever __ordinal_compute returns
    """
    # Validation receives the caller's raw link_func; None is accepted there.
    ordered_categories = __ordinal_validate_args(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, cat_order, link_func, grouping_col)

    # Fill in defaults only after validation has passed.
    if link_func is None:
        link_func = 'logit'
    if optim_params is None:
        optim_params = ''

    parsed_optim_params = __extract_optim_params(
        schema_madlib, optim_params, 'Ordinal')

    return __ordinal_compute(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, ordered_categories, link_func, grouping_col,
        parsed_optim_params, verbose)
| |
| # ======================================================================== |
| |
| |
def __ordinal_validate_args(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, cat_order, link_func, grouping_col):
    """
    Validate the training arguments and work out the ordered category list.

    @param schema_madlib Name of the MADlib schema
    @param source_table Name of the table holding the training data
    @param model_table Name of the output table
    @param dependent_varname Expression for the response
    @param independent_varname Expression for the feature array
    @param cat_order Category order, categories separated by '<'; None
           means the natural sort order of the category values
    @param link_func 'logit', 'probit' or None
    @param grouping_col Comma-separated grouping columns, or None

    @return List of category values as strings, lowest category first

    Errors are reported through plpy.error.
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)

    if link_func is not None and link_func not in ('logit', 'probit'):
        plpy.error("Ordinal error: Invalid link function!\n"
                   "Only 'logit' and 'probit' are supported.")

    # Rows taking part in the fit: no NULL response, no NULL feature.
    valid_rows_filter = """
                {independent_varname} IS NOT NULL
                AND NOT {schema_madlib}.array_contains_null(
                    {independent_varname})
                AND {dependent_varname} IS NOT NULL
        """.format(schema_madlib=schema_madlib,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname)

    raw_categories = plpy.execute("""
        SELECT array_agg(category ORDER BY category) AS category_list
        FROM (
            SELECT distinct {dependent_varname} AS category
            FROM {source_table}
            WHERE {valid_rows_filter}
        ) subq
        """.format(dependent_varname=dependent_varname,
                   source_table=source_table,
                   valid_rows_filter=valid_rows_filter))[0]['category_list']

    # array_agg() yields NULL (None), not an empty array, when no row
    # qualifies, so a plain len() check would raise TypeError here.
    if not raw_categories:
        plpy.error("Ordinal error: No non-null categories found!")
    if len(raw_categories) == 1:
        plpy.error("Ordinal error: Only a single valid category found!")

    first_category = raw_categories[0]
    # bool is a subclass of int and therefore has to be rejected explicitly.
    # `long` exists because this module targets a Python 2 runtime.
    if (isinstance(first_category, bool) or
            not isinstance(first_category, (int, float, long, str))):
        plpy.error("Ordinal error: Given category type is not supported!\n"
                   "Only numeric, character, binary data and enumerated types "
                   "are supported. Particularly, if the category type is boolean,"
                   "please use glm() binomial family instead.")

    if cat_order is None:
        # Sort the raw values, not their string forms, so that numeric
        # categories are ordered numerically (2 before 10).
        category_list = [str(c) for c in sorted(raw_categories)]
    else:
        known_categories = set(str(c) for c in raw_categories)
        order_list = [c.strip() for c in cat_order.split('<')]

        # The specification has to be a permutation of the categories found
        # in the data: same size, no category listed twice, nothing unknown.
        if (len(order_list) != len(known_categories) or
                len(set(order_list)) != len(order_list)):
            plpy.error("Ordinal error: category order specification is not valid!")

        for c in order_list:
            if c not in known_categories:
                plpy.error("Ordinal error: '{c}' is not found in source table {source_table}."
                           .format(c=c, source_table=source_table))

        category_list = order_list

    if grouping_col:
        # Every group must contain the complete set of categories.
        grouped_category_counts = plpy.execute("""
            SELECT array_agg(category_count) AS counts
            FROM (
                SELECT count(distinct {dependent_varname}) AS category_count
                FROM {source_table}
                WHERE {valid_rows_filter}
                GROUP BY {grouping_col}
            ) subq
            """.format(dependent_varname=dependent_varname,
                       source_table=source_table,
                       valid_rows_filter=valid_rows_filter,
                       grouping_col=grouping_col))[0]['counts']
        if any(c != len(category_list) for c in grouped_category_counts or []):
            plpy.error("Ordinal error: Categories are not consistent across "
                       "all groups!")

    return category_list
| |
| # ======================================================================== |
| |
| |
def __ordinal_compute(
        schema_madlib, tbl_source, tbl_output, col_dep_var, col_ind_var,
        category_list, link_func, grouping_col, optim_params_dict, verbose):
    """
    Fit the ordinal model and write the model table and its summary table.

    @param schema_madlib Name of the MADlib schema
    @param tbl_source Name of the table holding the training data
    @param tbl_output Name of the model table to create; the summary table
           gets the same name with the "_summary" postfix
    @param col_dep_var Expression for the response
    @param col_ind_var Expression for the feature array
    @param category_list Category values as strings, lowest category first
    @param link_func Name of the link function ('logit' or 'probit')
    @param grouping_col Comma-separated grouping columns, or None/empty
    @param optim_params_dict Optimizer settings merged into the query
           arguments; 'optimizer', 'max_iter' and 'tolerance' are read from
           the merged arguments when the summary table is written
    @param verbose Whether client_min_messages is set to info (else warning)

    @return None
    """
    # Remember the message level so it can be restored at the end.
    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    args = {'schema_madlib': schema_madlib,
            'rel_source': tbl_source,
            'tbl_output': tbl_output,
            'col_dep_var': col_dep_var,
            'col_ind_var': col_ind_var,
            'rel_state': unique_string(),
            'col_grp_iteration': unique_string(),
            'col_grp_state': unique_string(),
            'state_type': schema_madlib + ".bytea8",
            'n_categories': len(category_list),
            'link': link_func}

    args.update(optim_params_dict)

    if grouping_col:
        grouping_list = explicit_bool_to_text(
            tbl_source, _string_to_array_with_quotes(grouping_col),
            schema_madlib)
        grouping_str = ','.join([g + "::text" for g in grouping_list])
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str

    # Category values end up inside SQL string literals; double any single
    # quote so a value such as "don't know" cannot terminate the literal.
    escaped_categories = [c.replace("'", "''") for c in category_list]

    # CASE expression mapping each category value to its 0-based position,
    # evaluated when the aggregate is called.
    when_clauses = [
        "WHEN {d}::text = '{c}' THEN {i}".format(d=col_dep_var, c=c, i=i)
        for i, c in enumerate(escaped_categories)]
    args['category_expr'] = "CASE " + '\n'.join(when_clauses) + "\nEND"

    # The iteration count is read per group from the state table below, so
    # the return value is not needed here.
    __compute_ordinal(args)

    # output table
    grouping_str1 = "" if grouping_col is None else grouping_col + ","
    grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
    using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
    join_str = "," if grouping_str1 == "" else "join "
    args['category_str'] = ','.join(escaped_categories)

    q_out_table = """
        DROP TABLE IF EXISTS {tbl_output};
        CREATE TABLE {tbl_output} AS
        SELECT
            {grouping_str1}
            (result).coef_alpha AS coef_threshold,
            (result).std_err_alpha AS std_err_threshold,
            (result).z_stats_alpha AS z_stats_threshold,
            (result).p_values_alpha AS p_values_threshold,
            (result).loglik AS log_likelihood,
            (result).coef_beta AS coef_feature,
            (result).std_err_beta AS std_err_feature,
            (result).z_stats_beta AS z_stats_feature,
            (result).p_values_beta AS p_values_feature,
            (CASE WHEN result IS NULL THEN 0
                ELSE (result).num_rows_processed
            END)::bigint AS num_rows_processed,
            (CASE WHEN result IS NULL THEN num_rows
                ELSE num_rows - (result).num_rows_processed
            END)::bigint AS num_rows_skipped,
            {col_grp_iteration}::integer AS num_iterations
        FROM
        (
            SELECT
                {col_grp_iteration}, {grouping_str1} result, num_rows
            FROM
            (
                (
                    SELECT
                        {grouping_str1}
                        {schema_madlib}.__ordinal_result({col_grp_state}) AS result,
                        {col_grp_iteration}
                    FROM
                        {rel_state}
                ) t
                JOIN
                (
                    SELECT
                        {grouping_str1}
                        max({col_grp_iteration}) AS {col_grp_iteration}
                    FROM {rel_state}
                    GROUP BY {grouping_str2}
                ) s
                USING ({grouping_str1} {col_grp_iteration})
            ) q1
            {join_str}
            (
                SELECT
                    {grouping_str1}
                    count(*) AS num_rows
                FROM {rel_source}
                GROUP BY {grouping_str2}
            ) q2
            {using_str}
        ) q3
        """.format(grouping_str1=grouping_str1,
                   grouping_str2=grouping_str2,
                   using_str=using_str,
                   join_str=join_str,
                   **args)
    plpy.execute(q_out_table)

    # summary table
    failed_groups = plpy.execute("""
        SELECT count(*) AS num_failed_groups
        FROM {tbl_output}
        WHERE coef_threshold IS NULL
        """.format(**args))[0]
    all_groups = plpy.execute("""
        SELECT count(*) AS num_all_groups
        FROM {tbl_output}
        """.format(**args))[0]
    total_rows = plpy.execute("""
        SELECT
            sum(num_rows_processed) AS total_rows_processed,
            sum(num_rows_skipped) AS total_rows_skipped
        FROM {tbl_output}
        """.format(tbl_output=tbl_output))[0]

    args.update(failed_groups)
    args.update(all_groups)
    args.update(total_rows)

    tbl_output_summary = add_postfix(tbl_output, "_summary")
    plpy.execute("""
        CREATE TABLE {tbl_output_summary} AS
        SELECT
            'ordinal'::varchar AS method,
            '{rel_source}'::varchar AS source_table,
            '{tbl_output}'::varchar AS out_table,
            $msq${col_dep_var}$msq$::varchar AS dependent_varname,
            $msq${col_ind_var}$msq$::varchar AS independent_varname,
            '{category_str}'::varchar AS category_list,
            '{link}'::varchar AS link_func,
            {g_str}::varchar AS grouping_col,
            'optimizer={optimizer}, ' ||
            'max_iter={max_iter}, ' ||
            'tolerance={tolerance}'::varchar AS optimizer_params,
            {num_all_groups}::integer AS num_all_groups,
            {num_failed_groups}::integer AS num_failed_groups,
            {total_rows_processed}::bigint AS total_rows_processed,
            {total_rows_skipped}::bigint AS total_rows_skipped
        """.format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
                   tbl_output_summary=tbl_output_summary,
                   **args))

    # clean up
    plpy.execute("""
        DROP TABLE IF EXISTS pg_temp.{rel_state}
        """.format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)

    return None
| # ======================================================================== |
| |
| |
def ordinal_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the ordinal regression model

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: empty/None for the
           summary, 'usage'/'help'/'?' for the usage text,
           'example'/'examples' for a worked example
    @param kwargs Ignored

    Returns:
        A string, contains the help message
    """
    if not message:

        help_string = """
----------------------------------------------------------------
                        SUMMARY
----------------------------------------------------------------
Ordinal Linear Model:

Currently logit and probit link functions are supported.

For more details on function usage:
    SELECT {schema_madlib}.ordinal('usage')

For a small example on using the function:
    SELECT {schema_madlib}.ordinal('example')
"""
    elif message in ['usage', 'help', '?']:

        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal(
    source_table,         -- name of input table
    model_table,          -- name of model table
    dependent_varname,    -- name of dependent variable
    independent_varname,  -- names of independent variables
    cat_order,            -- category order specified by '<'
    link_func,            -- optional, parameter for link function
    grouping_col,         -- optional, default NULL, names of columns to group-by
    optim_params,         -- optional, parameters for optimizer
    verbose               -- optional, default FALSE, whether to print debug info
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
    <...>                                    -- grouping columns
    'coef_threshold' double precision[],     -- vector of threshold coefficients
    'std_err_threshold' double precision[],  -- vector of standard errors for threshold coefficients
    'z_stats_threshold' double precision[],  -- vector of z-statistics for threshold coefficients
    'p_values_threshold' double precision[], -- vector of p-values for threshold coefficients
    'log_likelihood' double precision,       -- log likelihood
    'coef_feature' double precision[],       -- vector of feature coefficients
    'std_err_feature' double precision[],    -- vector of standard errors for feature coefficients
    'z_stats_feature' double precision[],    -- vector of z-statistics for feature coefficients
    'p_values_feature' double precision[],   -- vector of p-values for feature coefficients
    'num_rows_processed' bigint,             -- numbers of rows processed
    'num_rows_skipped' bigint,               -- numbers of rows skipped
    'num_iterations' integer                 -- number of iterations run

A summary table named <model_table>_summary is also created at the same time, which has:
    method varchar,               -- modeling method name: 'ordinal'
    source_table varchar,         -- the data source table name
    out_table varchar,            -- the output table name
    dependent_varname varchar,    -- the dependent variable
    independent_varname varchar,  -- the independent variable
    category_list varchar,        -- ordered categories used for training
    link_func varchar,            -- link function: logit and probit supported
    grouping_col varchar,         -- grouping columns used in the regression
    optimizer_params varchar,     -- 'optimizer=...,max_iter=...,tolerance=...'
    num_all_groups integer,       -- how many groups
    num_failed_groups integer,    -- how many groups' fitting processes failed
    total_rows_processed bigint,  -- total numbers of rows processed
    total_rows_skipped bigint     -- total numbers of rows skipped
"""

    elif message in ['example', 'examples']:

        help_string = """

DROP TABLE IF EXISTS test3;
CREATE TABLE test3 (
    feat1 INTEGER,
    feat2 INTEGER,
    cat INTEGER
);
INSERT INTO test3(feat1, feat2, cat) VALUES
(1,35,1),
(2,33,0),
(3,39,1),
(1,37,1),
(2,31,1),
(3,36,0),
(2,36,1),
(2,31,1),
(2,41,1),
(2,37,1),
(1,44,1),
(3,33,2),
(1,31,1),
(2,44,1),
(1,35,1),
(1,44,0),
(1,46,0),
(2,46,1),
(2,46,2),
(3,49,1),
(2,39,0),
(2,44,1),
(1,47,1),
(1,44,1),
(1,37,2),
(3,38,2),
(1,49,0),
(2,44,0),
(3,61,2),
(1,65,2),
(3,67,1),
(3,65,2),
(1,65,2),
(2,67,2),
(1,65,2),
(1,62,2),
(3,52,2),
(3,63,2),
(2,59,2),
(3,65,2),
(2,59,0),
(3,67,2),
(3,67,2),
(3,60,2),
(3,67,2),
(3,62,2),
(2,54,2),
(3,65,2),
(3,62,2),
(2,59,2),
(3,60,2),
(3,63,2),
(3,65,2),
(2,63,1),
(2,67,2),
(2,65,2),
(2,62,2);

-- Run the ordinal logistic regression function.
DROP TABLE IF EXISTS test3_output;
DROP TABLE IF EXISTS test3_output_summary;
SELECT {schema_madlib}.ordinal('test3',
                      'test3_output',
                      'cat',
                      'ARRAY[feat1, feat2]',
                      '0<1<2',
                      'logit'
                      );

SELECT * from test3_output;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.ordinal('help')"

    # Every variant goes through format() so the caller's schema name is
    # used consistently, including in the example.
    return help_string.format(schema_madlib=schema_madlib)
| |
| # =============================================================================== |
| # Ordinal prediction function |
| # =============================================================================== |
| |
| |
def ordinal_predict(schema_madlib, model_table, predict_table,
                    predicted_value_tab, predict_type, verbose, **kwargs):
    """
    Compute the predicted value for ordinal regression

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param model_table Name of table containing training result from
           ordinal(); its "_summary" companion table is read as well
    @param predict_table Name of table containing new data to predict; it
           must have an "id" column (and the grouping columns, if the model
           was trained with grouping)
    @param predicted_value_tab Name of table to output the predict value
    @param predict_type Type of predict value: 'response' or 'probability'
           (None selects 'probability')
    @param verbose Whether verbose will be displayed
    @param kwargs We allow the caller to specify additional arguments (all of
           which will be ignored though). The purpose of this is to allow the
           caller to unpack a dictionary whose element set is a superset of
           the required arguments by this function.

    @return None

    With 'probability' the output table has one column per category (named
    after the category) plus id. With 'response' it has the columns id and
    category; NOTE: if several categories tie for the highest probability,
    one row per tied category is produced for that id.
    """
    module_name = 'ordinal_predict'

    # Validate the arguments
    input_tbl_valid(model_table, module_name)
    input_tbl_valid(predict_table, module_name)
    output_tbl_valid(predicted_value_tab, module_name)
    cols_in_tbl_valid(predict_table, _string_to_array("id"), module_name)

    if predict_type is None:
        predict_type = 'probability'
    # Reject a bad prediction type before any query is run.
    if predict_type not in ('probability', 'response'):
        plpy.error("Invalid prediction type!\n")

    model_table_summary = add_postfix(model_table, "_summary")
    input_tbl_valid(model_table_summary, module_name)

    # Remember the message level so it can be restored at the end, the same
    # way the training function does.
    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    # One round trip for everything needed from the summary table.
    summary = plpy.execute("""
        SELECT independent_varname, category_list, grouping_col, link_func
        FROM {model_table_summary}
        """.format(model_table_summary=model_table_summary))[0]
    ind_var = summary['independent_varname']
    cate_list = summary['category_list'].split(',')
    group_var = summary['grouping_col']
    link_func = summary['link_func']

    if group_var is None:
        grp_clause1 = ""
        grp_clause2 = ""
        grp_clause3 = ""
    else:
        cols_in_tbl_valid(predict_table,
                          _string_to_array_with_quotes(group_var),
                          module_name)
        group_cols = [c.strip() for c in group_var.split(',')]
        # match every new row with the model of its own group
        grp_clause1 = "WHERE " + " AND ".join([
            "{predict_table}.{c} = {model_table}.{c}".format(
                c=c, predict_table=predict_table, model_table=model_table)
            for c in group_cols])
        # carry the grouping columns through the subqueries
        grp_clause2 = ", " + ", ".join([
            "{model_table}.{c} as {c}".format(c=c, model_table=model_table)
            for c in group_cols])
        # pair linear predictors with the thresholds of the same group
        grp_clause3 = "WHERE " + " AND ".join([
            "subq2.{c} = subq3.{c}".format(c=c) for c in group_cols])

    if link_func == "logit":
        link_clause = "exp(gamma)/(1+exp(gamma))"
    elif link_func == "probit":
        link_clause = "{schema_madlib}.normal_cdf(gamma)".format(
            schema_madlib=schema_madlib)
    else:
        plpy.error("Invalid link function!\n")

    # Shared by both prediction types. For every id:
    #   subq2/subq3  linear predictor x.beta and the K-1 thresholds alpha_i
    #   subq4/subq5  cumulative probability link(alpha_i - x.beta)
    #   subq6        cumulative probabilities/categories as ordered arrays
    #   result       indpr = per-category probability, obtained by
    #                differencing [cum..., 1] against [0, cum...];
    #                catearray = all K categories
    category_prob_sql = """
        SELECT
            id,
            {schema_madlib}.array_sub(array_cat(prarray,ARRAY[1]::float8[]),
                array_cat(ARRAY[0]::float8[],prarray)) as indpr,
            array_append(catearray, (string_to_array(category_list,','))[{cate_list_len}]) as catearray
        FROM
            (
            SELECT
                id,
                array_agg(cumpr ORDER BY idx) as prarray,
                array_agg(category ORDER BY idx) as catearray
            FROM
                (
                SELECT
                    id,
                    {link_clause} as cumpr,
                    category,
                    idx
                FROM
                    (
                    SELECT
                        id,
                        (alpha-xbeta) as gamma,
                        category,
                        idx
                    FROM
                        (
                        SELECT
                            id,
                            ({schema_madlib}.array_dot(coef_feature, {ind_var}::float8[])) as xbeta
                            {grp_clause2}
                        FROM
                            {predict_table},
                            {model_table}
                        {grp_clause1}
                        )subq2,
                        (
                        SELECT
                            i as idx,
                            coef_threshold[i] as alpha,
                            (string_to_array(category_list,','))[i] as category
                            {grp_clause2}
                        FROM
                            {model_table},
                            {model_table_summary},
                            (
                            SELECT
                                generate_series(1,{cate_list_len_minus_one}) as i
                            )subq1
                        )subq3
                    {grp_clause3}
                    )subq4
                )subq5
            GROUP by id
            )subq6,
            {model_table_summary}
        """.format(schema_madlib=schema_madlib,
                   cate_list_len=len(cate_list),
                   cate_list_len_minus_one=len(cate_list) - 1,
                   link_clause=link_clause,
                   ind_var=ind_var,
                   predict_table=predict_table,
                   model_table=model_table,
                   model_table_summary=model_table_summary,
                   grp_clause1=grp_clause1,
                   grp_clause2=grp_clause2,
                   grp_clause3=grp_clause3)

    if predict_type == 'probability':
        # One output column per category; a double quote inside a category
        # name is doubled so it cannot end the quoted identifier.
        score_format = '\n'.join([
            "indpr[{j}] as \"{c}\",".format(j=i + 1, c=c.replace('"', '""'))
            for i, c in enumerate(cate_list)])
        plpy.execute("""
            CREATE TABLE {predicted_value_tab} AS
            SELECT {score_format}
                id
            FROM
                (
                {category_prob_sql}
                )subq7
            """.format(predicted_value_tab=predicted_value_tab,
                       score_format=score_format,
                       category_prob_sql=category_prob_sql))
    else:
        # predict_type == 'response': one (id, prob, category) row per
        # category, then keep the category whose probability is the maximum.
        unnested_sql = """
            SELECT
                id,
                unnest(indpr) as prob,
                unnest(catearray) as category
            FROM
                (
                {category_prob_sql}
                )subq7
            """.format(category_prob_sql=category_prob_sql)
        plpy.execute("""
            CREATE TABLE {predicted_value_tab} AS
            SELECT
                subq8.id as id,
                subq8.category as category
            FROM
                (
                SELECT
                    id,
                    max(prob) as max_prob
                FROM
                    (
                    {unnested_sql}
                    )subq8
                GROUP BY id
                )subq9,
                (
                {unnested_sql}
                )subq8
            WHERE subq9.id = subq8.id and subq9.max_prob = subq8.prob
            """.format(predicted_value_tab=predicted_value_tab,
                       unnested_sql=unnested_sql))

    plpy.execute("SET client_min_messages TO " + old_msg_level)

    return None
| |
| |
| # ========== help message for the prediction function ====================== |
def ordinal_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for prediction function for ordinal regression

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator: empty/None for the
           summary, 'usage'/'help'/'?' for the usage text,
           'example'/'examples' for a worked example
    @param kwargs Ignored

    Returns:
        A string, contains the help message
    """
    if not message:

        help_string = """
----------------------------------------------------------------
                        SUMMARY
----------------------------------------------------------------
Prediction function for ordinal linear regression:

Estimate the conditional probability or give the response category given
a new set of predictors.

For more details on function usage:
    SELECT {schema_madlib}.ordinal_predict('usage')

For a small example on using the function:
    SELECT {schema_madlib}.ordinal_predict('example')
"""
    elif message in ['usage', 'help', '?']:

        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.ordinal_predict(
    model_table,          -- Name of the table containing the output of ordinal()
    predict_table_input,  -- Name of the table containing new data
    output_table,         -- Name of the table storing the result of predicted values
    predict_type,         -- Support two types: "response" or "probability"
    verbose               -- Whether verbose is displayed
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output is a table keyed by the 'id' column of the input table. When
predict_type is response it has a 'category' column holding the predicted
category; when predict_type is probability it has one column per category
holding the probability of that category.
"""
    elif message in ['example', 'examples']:
        # The model table name must match the one created by the training
        # example of ordinal(), i.e. test3_output.
        help_string = """
-- run the training example first
ALTER TABLE test3 ADD COLUMN id SERIAL;
DROP TABLE IF EXISTS test3_predict;
SELECT {schema_madlib}.ordinal_predict('test3_output', 'test3', 'test3_predict', 'probability');
SELECT * FROM test3_predict;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.ordinal_predict('help')"

    return help_string.format(schema_madlib=schema_madlib)