| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| |
| """ |
| @file simple_logistic.py_in |
| |
| @brief Logistic Regression: Driver functions |
| |
| @namespace simple_logistic |
| |
| @brief Logistic Regression: Driver functions |
| """ |
| import plpy |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import table_is_empty |
| # ------------------------------------------------------------------------ |
| |
| |
def logregr_simple_train(
        schema_madlib, source_table, out_table, dependent_varname,
        independent_varname, max_iter=None,
        tolerance=None, verbose=None, **kwargs):
    """
    Entry point for training the simple logistic regression model

    All arguments are forwarded unchanged to __logregr_train_compute, which
    validates them and does the actual work.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source_table Name of relation containing the training data
    @param out_table Name of relation where the model will be written
    @param dependent_varname Name of dependent column in training data
           (of type BOOLEAN)
    @param independent_varname Name of independent column in training data
           (of type DOUBLE PRECISION[])
    @param max_iter The maximum number of iterations that are allowed
    @param tolerance The precision that the results should have
    @param verbose Forwarded as-is to the compute routine
    @param kwargs Extra keyword arguments, forwarded as-is. They exist so a
           caller can unpack a dictionary whose keys are a superset of the
           arguments required by this function.

    @return Whatever __logregr_train_compute returns
    """
    forwarded = (
        schema_madlib,
        source_table,
        out_table,
        dependent_varname,
        independent_varname,
        max_iter,
        tolerance,
        verbose,
    )
    return __logregr_train_compute(*forwarded, **kwargs)
| |
| # ======================================================================== |
| |
| |
def _logregr_name_is_missing(name):
    """Return True when name is None, blank, or the literal string 'null'."""
    return not name or name.strip().lower() in ('null', '')


def __logregr_train_compute(schema_madlib, tbl_source, tbl_output, dep_col,
                            ind_col, max_iter, tolerance, verbose, **kwargs):
    """
    Validate the arguments, run the training iterations and store the model

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param tbl_source Name of relation containing the training data
    @param tbl_output Name of relation the model is written to. An existing
           relation of that name is dropped first.
    @param dep_col Dependent variable; interpolated into the query and cast
           to BOOLEAN
    @param ind_col Independent variables; interpolated into the query and
           cast to DOUBLE PRECISION[]
    @param max_iter Number of iterations to run, must be positive. None
           selects 20, the default advertised in the help message.
    @param tolerance Must be non-negative. None selects 0.0001, the default
           advertised in the help message. NOTE: the value is only
           validated; this routine does not use it as a stopping criterion
           and always runs max_iter iterations.
    @param verbose Currently unused
    @param kwargs Ignored

    @return None. The model (coef, log_likelihood) is written to tbl_output.
    """
    if _logregr_name_is_missing(tbl_source):
        plpy.error("Logregr error: Invalid data table name!")
    if not table_exists(tbl_source):
        plpy.error("Logregr error: Data table does not exist!")
    if table_is_empty(tbl_source):
        plpy.error("Logregr error: Data table is empty!")

    if _logregr_name_is_missing(tbl_output):
        plpy.error("Logregr error: Invalid output table name!")

    if _logregr_name_is_missing(dep_col):
        plpy.error("Logregr error: Invalid dependent column name!")

    # NOTE(review): there is no check that dep_col exists in tbl_source;
    # presumably because dep_col may be an expression rather than a plain
    # column name -- confirm before adding one.

    # Strip before comparing, consistently with the checks above, so that a
    # whitespace-only value is rejected here instead of failing in SQL.
    if _logregr_name_is_missing(ind_col):
        plpy.error("Logregr error: Invalid independent column name!")

    # The public entry point defaults both values to None; comparing None
    # with a number is a spurious error on Python 2 and a TypeError on
    # Python 3, so fall back to the documented defaults instead.
    if max_iter is None:
        max_iter = 20
    if tolerance is None:
        tolerance = 0.0001

    if max_iter <= 0:
        plpy.error("Logregr error: Maximum number of iterations must be positive!")

    if tolerance < 0:
        plpy.error("Logregr error: The tolerance cannot be negative!")

    # One pass over the data: the step aggregate takes the previous state
    # ($1) and returns the updated one.
    update_plan = plpy.prepare(
        """
        SELECT
            {schema_madlib}.__logregr_simple_step(
                ({dep_col})::boolean,
                ({ind_col})::double precision[],
                ($1))
        FROM {tbl_source}
        """.format(
            schema_madlib=schema_madlib,
            dep_col=dep_col,
            ind_col=ind_col,
            tbl_source=tbl_source), ["double precision[]"])

    state = None
    for _ in range(max_iter):
        res_tuple = plpy.execute(update_plan, [state])
        # The query yields a single unnamed column. dict.values() is a view
        # on Python 3 and cannot be indexed, so take its first element in a
        # way that works on both Python 2 and 3.
        state = next(iter(res_tuple[0].values()))

    output_table = plpy.prepare(
        """
        drop table if exists {tbl_output};
        create table {tbl_output} as
            select
                (result).coef as coef,
                (result).log_likelihood as log_likelihood
            from
            (
                select * from
                {schema_madlib}.__logregr_simple_finalizer($1)
            ) result
        """.format(schema_madlib=schema_madlib,
                   tbl_output=tbl_output), ["double precision[]"])

    plpy.execute(output_table, [state])
    return None
| # -------------------------------------------------------------------- |
| |
| |
def logregr_simple_help_msg(schema_madlib, message, **kwargs):
    """ Help message for simple logistic regression

    @param schema_madlib Name of the MADlib schema; substituted into the text
    @param message A string, the help message indicator: None or empty for
           the summary, 'usage' / 'help' / '?' for the usage text,
           'example' / 'examples' for a worked example. Matching ignores
           case and surrounding whitespace.
    @param kwargs Ignored

    Returns:
        A string, contains the help message
    """
    topic = message.strip().lower() if message else ''

    if not topic:

        help_string = """
----------------------------------------------------------------
                            SUMMARY
----------------------------------------------------------------
Binomial logistic regression models the relationship between a
dichotomous dependent variable and one or more predictor variables.

The dependent variable may be a Boolean value or a categorical variable
that can be represented with a Boolean expression.

For more details on function usage:
    SELECT {schema_madlib}.logregr_simple_train('usage')

For a small example on using the function:
    SELECT {schema_madlib}.logregr_simple_train('example')
"""
    elif topic in ('usage', 'help', '?'):

        # The OUTPUT section lists exactly the columns of the table created
        # by the training routine (coef and log_likelihood).
        help_string = """
------------------------------------------------------------------
                            USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.logregr_simple_train(
    source_table,         -- name of input table
    out_table,            -- name of output table
    dependent_varname,    -- name of dependent variable
    independent_varname,  -- names of independent variables
    max_iter,             -- optional, default 20, maximum iteration number
    tolerance,            -- optional, default 0.0001, the stopping threshold
    verbose               -- optional, default FALSE, whether to print useful info
);

------------------------------------------------------------------
                            OUTPUT
------------------------------------------------------------------
The output table ('out_table' above) has the following columns:
    'coef',            double precision[],  -- vector of fitting coefficients
    'log_likelihood',  double precision     -- log likelihood
"""
    elif topic in ('example', 'examples'):

        # "\\." keeps the literal backslash-dot COPY terminator without
        # relying on an invalid escape sequence.
        help_string = """
CREATE TABLE patients( id INTEGER NOT NULL,
                       second_attack BOOLEAN,
                       treatment INTEGER,
                       trait_anxiety INTEGER);
COPY patients FROM STDIN WITH DELIMITER '|';
  1 | True  | 1 | 70
  3 | True  | 1 | 50
  5 | True  | 0 | 40
  7 | True  | 0 | 75
  9 | True  | 0 | 70
 11 | False | 1 | 65
 13 | False | 1 | 45
 15 | False | 1 | 40
 17 | False | 0 | 55
 19 | False | 0 | 50
  2 | True  | 1 | 80
  4 | True  | 0 | 60
  6 | True  | 0 | 65
  8 | True  | 0 | 80
 10 | True  | 0 | 60
 12 | False | 1 | 50
 14 | False | 1 | 35
 16 | False | 1 | 50
 18 | False | 0 | 45
 20 | False | 0 | 60
\\.

-- Drop output tables before calling the function
DROP TABLE IF EXISTS patients_logregr;
DROP TABLE IF EXISTS patients_logregr_summary;

SELECT {schema_madlib}.logregr_simple_train( 'patients',
                                    'patients_logregr',
                                    'second_attack',
                                    'ARRAY[1, treatment, trait_anxiety]');

SELECT * from patients_logregr;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.logregr_simple_train('help')"

    return help_string.format(schema_madlib=schema_madlib)