| # coding=utf-8 |
| |
| """ |
| @file logistic.py_in |
| |
| @brief Logistic Regression: Driver functions |
| |
| @namespace logistic |
| |
| @brief Logistic Regression: Driver functions |
| """ |
| import plpy |
| from utilities.group_control import GroupIterationController |
| from utilities.utilities import __unique_string |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import scalar_col_has_no_null |
| from utilities.utilities import _string_to_array |
| |
| # ======================================================================== |
| |
def __compute_logregr(schema_madlib, rel_args, rel_state, rel_source,
                      dep_col, ind_col, optimizer, grouping_col,
                      grouping_str, **kwargs):
    """
    Drive the iterative fitting of the logistic regression coefficients

    The function repeatedly invokes the SQL step function
    <tt>__logregr_{optimizer}_step</tt> through a GroupIterationController
    until the stopping test succeeds, i.e. until the iteration counter has
    reached <tt>max_iter</tt> or the distance between the previous and the
    current state (as reported by <tt>__logregr_{optimizer}_step_distance</tt>)
    falls below <tt>tolerance</tt>.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the relation holding the optimizer parameters
           max_iter and tolerance
    @param rel_state Name of the relation in which the iteration states are
           stored
    @param rel_source Name of relation containing the training data
    @param dep_col Name of dependent column in training data (cast to
           BOOLEAN in the step call)
    @param ind_col Name of independent column in training data (cast to
           DOUBLE PRECISION[] in the step call)
    @param optimizer Name of the optimizer; it is spliced into the names of
           the step and step-distance functions ('irls', 'cg' or 'igd')
    @param grouping_col Grouping column(s); handed unchanged to the
           GroupIterationController
    @param grouping_str Grouping expression string; handed unchanged to the
           GroupIterationController
    @param kwargs Additional arguments, all of which are ignored. This lets
           the caller unpack a dictionary whose key set is a superset of the
           arguments required by this function.

    @return The value of the controller's <tt>iteration</tt> attribute after
            the loop has finished
    """
    # The {placeholders} below are not filled in here; they are presumably
    # substituted by GroupIterationController from its keyword arguments.
    step_sql = """
        {schema_madlib}.__logregr_{optimizer}_step(
            ({dep_col})::boolean,
            ({ind_col})::double precision[],
            {rel_state}._state)
        """
    stop_sql = """
        {iteration} >= _args.max_iter
        or
        {schema_madlib}.__logregr_{optimizer}_step_distance(
            _state_previous, _state_current) < _args.tolerance
        """

    controller = GroupIterationController(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="double precision[]",
        # truncAfterIteration=False,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        ind_col=ind_col,
        dep_col=dep_col,
        optimizer=optimizer,
        grouping_col=grouping_col,
        grouping_str=grouping_str)

    with controller as it:
        it.iteration = 0
        # At least one step is always taken; afterwards keep stepping for as
        # long as the stopping test fails.
        it.update(step_sql)
        while not it.test(stop_sql):
            it.update(step_sql)

    return controller.iteration
| |
| # ======================================================================== |
| |
| |
def logregr_train(schema_madlib, tbl_source, tbl_output, dep_col, ind_col,
                  grouping_col, max_iter, optimizer, tolerance, **kwargs):
    """
    Train logistic model

    Validates the arguments and then builds the model table.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param tbl_source Name of relation containing the training data
    @param tbl_output Name of relation where model will be outputted
    @param dep_col Name of dependent column in training data (of type BOOLEAN)
    @param ind_col Name of independent column in training data (of type
           DOUBLE PRECISION[])
    @param grouping_col List of column names on which to group the data
    @param max_iter The maximum number of iterations that are allowed.
    @param optimizer Name of the optimizer. 'newton' or 'irls': Iteratively
           reweighted least squares, 'cg': conjugate gradient or 'igd':
           incremental gradient descent
    @param tolerance The precision that the results should have
    @param kwargs Additional arguments. The optional entry <tt>verbose</tt>
           (default: False) selects the message level used while training;
           all remaining entries are forwarded and otherwise ignored. This
           lets the caller unpack a dictionary whose key set is a superset
           of the arguments required by this function.

    @return Whatever __logregr_train_compute returns (None); the model
            itself is written to the table <tt>tbl_output</tt>
    """
    # The validator returns the normalized optimizer name
    # ('newton' is an alias for 'irls').
    optimizer = __logregr_validate_args(schema_madlib,
                                        tbl_source, tbl_output,
                                        dep_col, ind_col,
                                        grouping_col, max_iter,
                                        optimizer, tolerance)

    # 'verbose' is required by the compute function but only ever arrives
    # through kwargs; default it so that a caller who omits it does not hit
    # a TypeError.
    verbose = kwargs.pop("verbose", False)

    return __logregr_train_compute(schema_madlib, tbl_source, tbl_output,
                                   dep_col, ind_col, grouping_col, max_iter,
                                   optimizer, tolerance, verbose, **kwargs)
| |
| # ======================================================================== |
| |
| |
def __logregr_validate_args(schema_madlib, tbl_source, tbl_output, dep_col,
                            ind_col, grouping_col, max_iter, optimizer,
                            tolerance):
    """
    Validate the arguments of the logistic regression training function

    Every failed check is reported through plpy.error.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param tbl_source Name of relation containing the training data; must
           exist and must not be empty
    @param tbl_output Name of the output relation; must be a usable name and
           must not exist yet
    @param dep_col Name of the dependent column; must not contain NULLs
    @param ind_col Name of the independent column
    @param grouping_col Optional grouping column list; if given, every
           column must exist in tbl_source
    @param max_iter Maximum number of iterations; must be positive
    @param optimizer Name of the optimizer: 'newton', 'irls', 'cg' or 'igd'
    @param tolerance Convergence tolerance; must not be negative

    @return The normalized optimizer name ('newton' is mapped to 'irls')
    """
    if not tbl_source or tbl_source.lower() in ('null', '') or \
            (not table_exists(tbl_source)):
        plpy.error("Logregr error: Data table does not exist!")

    # The name itself has to be checked before it is handed to
    # table_exists(); a missing (None) name would otherwise fail with an
    # unrelated exception instead of a readable message.
    if not tbl_output or tbl_output.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid output table name!")

    if table_exists(tbl_output):
        plpy.error("Output table name already exists. Drop the table before calling the function.")

    if table_is_empty(tbl_source):
        plpy.error("Logregr error: Data table is empty!")

    # NOTE(review): dep_col is deliberately not checked with
    # columns_exist_in_table -- presumably because it may be an expression
    # rather than a plain column name; confirm before re-enabling.
    if not dep_col or dep_col.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid dependent column name!")

    if not scalar_col_has_no_null(tbl_source, dep_col):
        # Adjacent literals instead of a backslash continuation, which
        # would embed the source indentation into the message.
        plpy.error("Logregr error: Dependent variable has Null values! "
                   "Please filter out Null values before using this function!")

    if not ind_col or ind_col.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid independent column name!")

    if grouping_col:
        # grouping_col is optional but if provided should be valid column name
        if grouping_col.lower() in ('null', ''):
            plpy.error("Logregr error: Invalid grouping columns name!")

        if not columns_exist_in_table(tbl_source,
                                      _string_to_array(grouping_col),
                                      schema_madlib):
            plpy.error("Logregr error: Grouping column does not exist!")

    if max_iter is None or max_iter <= 0:
        plpy.error("Logregr error: Maximum number of iterations must be positive!")

    if tolerance is None or tolerance < 0:
        plpy.error("Logregr error: The tolerance cannot be negative!")

    if optimizer == "newton":
        optimizer = "irls"
    elif optimizer not in ("irls", "cg", "igd"):
        plpy.error(""" Logregr error: Unknown optimizer requested.
            Must be 'newton'/'irls', 'cg', or 'igd'.
            """)

    return optimizer
| |
| # ======================================================================== |
| |
| |
def __logregr_train_compute(schema_madlib, tbl_source, tbl_output, dep_col,
                            ind_col, grouping_col, max_iter, optimizer,
                            tolerance, verbose=False, **kwargs):
    """
    Create an output table (drop if exists) that contains the logistic
    regression model

    The function
      -# lowers or raises client_min_messages depending on @c verbose,
      -# stores max_iter and tolerance in a temporary argument table,
      -# runs the iterative optimizer via __compute_logregr,
      -# turns the final state of every group into one row of tbl_output
         (all model columns are NULL for groups whose result status is 2),
      -# reports the number of succeeded/failed groups when grouping is used,
      -# drops the temporary tables and restores client_min_messages.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param tbl_source Name of relation containing the training data
    @param tbl_output Name of relation where the model is written
    @param dep_col Name of dependent column in training data
    @param ind_col Name of independent column in training data
    @param grouping_col Comma-separated grouping columns; None or an empty
           string means "no grouping"
    @param max_iter The maximum number of iterations that are allowed
    @param optimizer Name of the optimizer: 'irls', 'newton', 'cg' or 'igd'
    @param tolerance The precision that the results should have
    @param verbose If true, client_min_messages is set to 'warning' while
           training, otherwise to 'error'
    @param kwargs Additional arguments, all of which are ignored

    @return None
    """
    old_msg_level = plpy.execute("select setting from pg_settings where "
                                 "name='client_min_messages'")[0]['setting']
    if verbose:
        plpy.execute("set client_min_messages to warning")
    else:
        plpy.execute("set client_min_messages to error")

    # SQL function that converts a final optimizer state into a result row
    result_functions = dict(irls="__logregr_irls_result",
                            newton="__logregr_irls_result",
                            cg="__logregr_cg_result",
                            igd="__logregr_igd_result")

    args = dict(schema_madlib=schema_madlib,
                tbl_source=tbl_source,
                tbl_output=tbl_output,
                dep_col=dep_col,
                ind_col=ind_col,
                max_iter=max_iter,
                optimizer=optimizer,
                tolerance=tolerance,
                tbl_logregr_args=__unique_string(),
                tbl_logregr_state=__unique_string())

    plpy.execute("select {schema_madlib}.create_schema_pg_temp()".format(**args))
    plpy.execute(
        """
        drop table if exists pg_temp.{tbl_logregr_args};
        create table pg_temp.{tbl_logregr_args} as
            select
                {max_iter} as max_iter,
                {tolerance} as tolerance
        """.format(**args))

    # Derive every grouping-dependent SQL fragment from one single test, so
    # that None and the empty string are both treated as "no grouping".
    # (Testing "is None" for some fragments and truthiness for others turns
    # an empty string into a dangling "," and an empty GROUP BY.)
    if grouping_col:
        grouping_str = ','.join(col + "::text"
                                for col in _string_to_array(grouping_col))
        grouping_str1 = grouping_col + ","
        grouping_str2 = grouping_col
    else:
        grouping_col = None
        grouping_str = "Null"
        grouping_str1 = ""
        grouping_str2 = "1 = 1"

    # Called for its side effect: it fills the state table read below.
    __compute_logregr(schema_madlib, args["tbl_logregr_args"],
                      args["tbl_logregr_state"], tbl_source,
                      dep_col, ind_col, optimizer,
                      grouping_col=grouping_col,
                      grouping_str=grouping_str)

    # For every group keep only the state of its last iteration and expand
    # it into the model columns.
    plpy.execute(
        """
        drop table if exists {tbl_output};
        create table {tbl_output} as
            select
                {grouping_str1}
                (case when (result).status = 2 then NULL::double precision[]
                    else (result).coef end) as coef,
                (case when (result).status = 2 then NULL::double precision
                    else (result).log_likelihood end) as log_likelihood,
                (case when (result).status = 2 then NULL::double precision[]
                    else (result).std_err end) as std_err,
                (case when (result).status = 2 then NULL::double precision[]
                    else (result).z_stats end) as z_stats,
                (case when (result).status = 2 then NULL::double precision[]
                    else (result).p_values end) as p_values,
                (case when (result).status = 2 then NULL::double precision[]
                    else (result).odds_ratios end) as odds_ratios,
                (case when (result).status = 2 then NULL::double precision
                    else (result).condition_no end) as condition_no,
                _iteration as num_iterations
            from
                (select
                    {grouping_str1}
                    {schema_madlib}.{fnName}(_state) as result,
                    _iteration
                from
                    {tbl_logregr_state}) t
            join
                (
                select
                    {grouping_str1}
                    max(_iteration) as _iteration
                from {tbl_logregr_state}
                group by {grouping_str2}
                ) s
            using ({grouping_str1} _iteration)
        """.format(grouping_str1=grouping_str1,
                   grouping_str2=grouping_str2,
                   fnName=result_functions[optimizer],
                   **args))

    # A NULL coefficient vector marks a group whose fit failed (status = 2).
    failed_groups = plpy.execute(
        """
        select count(*) as count
        from {tbl_output}
        where coef is Null
        """.format(**args))[0]["count"]
    all_groups = plpy.execute(
        """
        select count(*) as count
        from {tbl_output}
        """.format(**args))[0]["count"]

    if grouping_col:
        plpy.info(str(all_groups - failed_groups) +
                  " groups succesfully passed, and " +
                  str(failed_groups) + " groups failed")

    plpy.execute("""
        drop table if exists pg_temp.{tbl_logregr_args};
        drop table if exists pg_temp.{tbl_logregr_state}
        """.format(**args))

    plpy.execute("set client_min_messages to " + old_msg_level)
    return None