| /* ----------------------------------------------------------------------- *//** |
| * |
| * @file cross_validation.sql_in |
| * |
| * @brief SQL functions for cross validation |
| * @date January 2011 |
| * |
| * @sa For a brief introduction to the usage of cross validation, see the |
| * module description \ref grp_validation. |
| * |
| *//* ----------------------------------------------------------------------- */ |
| |
| |
| m4_include(`SQLCommon.m4') |
| |
| /** |
| @addtogroup grp_validation |
| |
| @brief Estimates the fit of a predictive model given a data set and specifications for the training, prediction, and error estimation functions. |
| |
| <div class="toc"><b>Contents</b> |
| <ul> |
| <li><a href="#cvfunction">Cross-Validation Function</a></li> |
| <li><a href="#examples">Examples</a></li> |
| <li><a href="#notes">Notes</a></li> |
| <li><a href="#background">Technical Background</a></li> |
| <li><a href="#related">Related Topics</a></li> |
| </ul> |
| </div> |
| |
| Estimates the fit of a predictive model given a data set and specifications for |
| the training, prediction, and error estimation functions. |
| |
| Cross validation, sometimes called rotation estimation, is a technique for |
| assessing how the results of a statistical analysis will generalize to an |
| independent data set. It is mainly used in settings where the goal is |
| prediction, and you want to estimate how accurately a predictive model will |
| perform in practice. |
| |
| The cross-validation function provided by this module is very flexible and |
| can work with algorithms you want to cross validate, including algorithms you |
| write yourself. Among the inputs to the cross-validation function are |
| specifications of the modelling, prediction, and error metric functions. These |
| three-part specifications include the name of the function, an array of |
| arguments to pass to the function, and an array of the data types of the |
| arguments. This makes it possible to use functions from other MADlib modules |
| or user-defined functions that you supply. |
| |
| - The modelling (training) function takes in a given data set with independent |
| and dependent variables and produces a model, which is stored in an output |
| table. |
| |
| - The prediction function takes in the model generated by the modelling function |
| and a different data set with independent variables, and produces a prediction |
| of the dependent variables based on the model, which is stored in an |
| output table. The prediction function should take a unique ID column name in |
| the data table as one of the inputs, so that the prediction result can be |
| compared with the validation values. |
| Note: Prediction functions in some MADlib modules do not save their results into an output |
| table. Such prediction functions are not suitable for use with this cross-validation module. |
| |
| - The error metric function compares the prediction results with the known |
| values of the dependent variables in the data set that was fed into the |
| prediction function. It computes the error metric using the specified error |
| metric function, and stores the results in a table. |
| |
| Other inputs include the output table name, k value for the k-fold |
| cross validation, and how many folds to try. For example, you can choose to run a |
| simple validation instead of a full cross validation. |
| |
| |
| @anchor cvfunction |
| @par Cross-Validation Function |
| |
| <pre class="syntax"> |
| cross_validation_general( modelling_func, |
| modelling_params, |
| modelling_params_type, |
| param_explored, |
| explore_values, |
| predict_func, |
| predict_params, |
| predict_params_type, |
| metric_func, |
| metric_params, |
| metric_params_type, |
| data_tbl, |
| data_id, |
| id_is_random, |
| validation_result, |
| data_cols, |
| fold_num |
| )</pre> |
| \b Arguments |
| <dl class="arglist"> |
| <dt>modelling_func</dt> |
| <dd>VARCHAR. The name of the function that trains the model.</dd> |
| |
| <dt>modelling_params</dt> |
| <dd>VARCHAR[]. An array of parameters to supply to the modelling function.</dd> |
| |
| <dt>modelling_params_type</dt> |
| <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the modelling function.</dd> |
| |
| <dt>param_explored</dt> |
| <dd>VARCHAR. The name of the parameter that will be checked to find the optimum value. The name must appear in the \e modelling_params array.</dd> |
| |
| <dt>explore_values</dt> |
| <dd>VARCHAR[]. An array of the values of the \e param_explored parameter that are to be studied.</dd> |
| |
| <dt>predict_func</dt> |
| <dd>VARCHAR. The name of the prediction function.</dd> |
| |
| <dt>predict_params</dt> |
| <dd>VARCHAR[]. An array of parameters to supply to the prediction function.</dd> |
| |
| <dt>predict_params_type</dt> |
| <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the prediction function.</dd> |
| |
| <dt>metric_func</dt> |
| <dd>VARCHAR. The name of the function for measuring errors.</dd> |
| |
| <dt>metric_params</dt> |
| <dd>VARCHAR[]. An array of parameters to supply to the error metric function.</dd> |
| |
| <dt>metric_params_type</dt> |
| <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the metric function.</dd> |
| |
| <dt>data_tbl</dt> |
| <dd>VARCHAR. The name of the data table that will be split into training and validation parts.</dd> |
| |
| <dt>data_id</dt> |
| <dd>VARCHAR. The name of the column containing a unique ID associated with |
| each row, or NULL if the table has no such column. |
| |
| Ideally, the data set has a unique ID for each row so that it is easier to |
| partition the data set into the training part and the validation part. Set the |
| \e id_is_random argument to inform the cross-validation function whether |
| the ID value is randomly assigned to each row. If it is not randomly |
| assigned, the cross-validation function generates a random ID for each row. |
| </dd> |
| |
| <dt>id_is_random</dt> |
| <dd>BOOLEAN. TRUE if the provided ID is randomly assigned to each row.</dd> |
| |
| <dt>validation_result</dt> |
| <dd>VARCHAR. The name of the table to store the output of the cross-validation function. The output table has the following columns: |
| <table class="output"> |
| <tr> |
| <th>param_explored</th> |
| <td>The name of the parameter checked to find the optimum value. This is the |
| same name specified in the \e param_explored argument of the \e cross_validation_general() function.</td> |
| </tr> |
| <tr> |
| <th>average error</th> |
| <td>The average of the errors computed by the error metric function.</td> |
| </tr> |
| <tr> |
| <th>standard deviation of error</th> |
| <td>The standard deviation of the errors.</td> |
| </tr> |
| </table> |
| </dd> |
| |
| <dt>data_cols</dt> |
| <dd>VARCHAR[]. An array of names of data columns to use in the calculation. |
| When its value is NULL, the function will automatically figure out all the column names of the data table. |
| This is only used if the <em>data_id</em> argument is NULL, otherwise it is |
| ignored. |
| |
| If the data set has no unique ID for each row, the cross-validation function |
| copies the data set to a temporary table with a randomly assigned ID column. |
| Setting this argument to the list of independent and dependent |
| variables that are to be used in the calculation minimizes the copying |
| workload by only copying the required data. |
| |
| The new temporary table is dropped after the computation has finished. |
| </dd> |
| <dt><em>fold_num</em></dt> |
| <dd>INTEGER, default: 10. The value of k, that is, the number of folds. Each validation round uses a 1/fold_num |
| fraction of the data for validation.</dd> |
| </dl> |
| |
| The parameter arrays for the modelling, prediction and metric functions can include the following special keywords: |
| |
| - <em>\%data%</em> – The argument position for training/validation data |
| |
| - <em>\%model%</em> – The argument position for the output/input of modelling/prediction function |
| |
| - <em>\%id%</em> – The argument position of the unique ID column (user-provided or generated by the cross-validation function, as described above) |
| |
| - <em>\%prediction%</em> – The argument position for the output/input of prediction/metric function |
| |
| - <em>\%error%</em> – The argument position for the output of the metric function |
| |
| <b>Note</b>: If the argument <em>explore_values</em> is NULL or has zero length, then the cross-validation function will only run a data folding. |
| |
| |
| @anchor examples |
| @examp |
| |
| -# Load some sample data: |
| <pre class="example"> |
| DROP TABLE IF EXISTS houses; |
| CREATE TABLE houses ( id INT, |
| tax INT, |
| bedroom INT, |
| bath FLOAT, |
| size INT, |
| lot INT, |
| zipcode INT, |
| price INT, |
| high_priced BOOLEAN |
| ); |
| INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode, high_priced) VALUES |
| (1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301, 'f'::boolean), |
| (2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301, 'f'::boolean), |
| (3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301, 'f'::boolean), |
| (4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301, 'f'::boolean), |
| (5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301, 't'::boolean), |
| (6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301, 'f'::boolean), |
| (7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301, 't'::boolean), |
| (8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301, 't'::boolean), |
| (9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301, 't'::boolean), |
| (10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301, 't'::boolean), |
| (11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301, 'f'::boolean), |
| (12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301, 't'::boolean), |
| (13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301, 't'::boolean), |
| (14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301, 't'::boolean), |
| (15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301, 'f'::boolean), |
| (16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010, 'f'::boolean), |
| (17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010, 't'::boolean), |
| (18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010, 'f'::boolean), |
| (19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean), |
| (20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean), |
| (21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean), |
| (22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean), |
| (23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean), |
| (24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean), |
| (25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean), |
| (26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean), |
| (27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean), |
| (28 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean), |
| (29 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean), |
| (30 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean), |
| (31 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean), |
| (32 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean), |
| (33 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean), |
| (34 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean), |
| (35 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean), |
| (36 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean); |
| </pre> |
| |
| -# Use the general function to explore lambda values |
| for elastic net. (Note that elastic net also has a |
| built in cross validation function |
| for selecting elastic net control parameter alpha and |
| regularization value lambda.) |
| <pre class="example"> |
| DROP TABLE IF EXISTS houses_cv_results; |
| SELECT madlib.cross_validation_general( |
| -- modelling_func |
| 'madlib.elastic_net_train', |
| -- modelling_params |
| '{%%data%, %%model%, price, "array[tax, bath, size]", gaussian, 0.5, lambda, TRUE, NULL, fista, |
| "{eta = 2, max_stepsize = 2, use_active_set = t}", |
| NULL, 200, 1e-6}'::varchar[], |
| -- modelling_params_type |
| '{varchar, varchar, varchar, varchar, varchar, double precision, |
| double precision, boolean, varchar, varchar, varchar, varchar, |
| integer, double precision}'::varchar[], |
| -- param_explored |
| 'lambda', |
| -- explore_values |
| '{0.1, 0.2}'::varchar[], |
| -- predict_func |
| 'madlib.elastic_net_predict', |
| -- predict_params |
| '{%%model%, %%data%, %%id%, %%prediction%}'::varchar[], |
| -- predict_params_type |
| '{text, text, text, text}'::varchar[], |
| -- metric_func |
| 'madlib.mse_error', |
| -- metric_params |
| '{%%prediction%, %%data%, %%id%, price, %%error%}'::varchar[], |
| -- metric_params_type |
| '{varchar, varchar, varchar, varchar, varchar}'::varchar[], |
| -- data_tbl |
| 'houses', |
| -- data_id |
| 'id', |
| -- id_is_random |
| FALSE, |
| -- validation_result |
| 'houses_cv_results', |
| -- data_cols |
| NULL, |
| -- fold_num |
| 3 |
| ); |
| SELECT * FROM houses_cv_results; |
| </pre> |
| Results from the lambda values explored: |
| <pre class="result"> |
| lambda | mean_squared_error_avg | mean_squared_error_stddev |
| --------+------------------------+--------------------------- |
| 0.1 | 1094965503.24269 | 411974996.039577 |
| 0.2 | 1093350170.40664 | 411072137.632718 |
| (2 rows) |
| </pre> |
| |
| -# Here we use the general function to explore |
| maximum number of iterations for logistic regression: |
| <pre class="example"> |
| DROP TABLE IF EXISTS houses_logregr_cv; |
| SELECT madlib.cross_validation_general( |
| -- modelling_func |
| 'madlib.logregr_train', |
| -- modelling_params |
| '{%%data%, %%model%, high_priced, "ARRAY[1, bedroom, bath, size]", NULL, max_iter}'::varchar[], |
| -- modelling_params_type |
| '{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[], |
| -- param_explored |
| 'max_iter', |
| -- explore_values |
| '{2, 10, 40, 100}'::varchar[], |
| -- predict_func |
| 'madlib.cv_logregr_predict', |
| -- predict_params |
| '{%%model%, %%data%, "ARRAY[1, bedroom, bath, size]", id, %%prediction%}'::varchar[], |
| -- predict_params_type |
| '{varchar, varchar,varchar,varchar,varchar}'::varchar[], |
| -- metric_func |
| 'madlib.misclassification_avg', |
| -- metric_params |
| '{%%prediction%, %%data%, id, high_priced, %%error%}'::varchar[], |
| -- metric_params_type |
| '{varchar, varchar, varchar, varchar, varchar}'::varchar[], |
| -- data_tbl |
| 'houses', |
| -- data_id |
| 'id', |
| -- id_is_random |
| FALSE, |
| -- validation_result |
| 'houses_logregr_cv', |
| -- data_cols |
| NULL, |
| -- fold_num |
| 5 |
| ); |
| SELECT * FROM houses_logregr_cv; |
| </pre> |
| Results from the explored number of iterations: |
| <pre class="result"> |
| max_iter | error_rate_avg | error_rate_stddev |
| ----------+------------------------+-------------------------------------------- |
| 2 | 0.19285714285714285714 | 0.1589185390091927774733662830554976076700 |
| 10 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213 |
| 40 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213 |
| 100 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213 |
| (4 rows) |
| </pre> |
| |
| @anchor notes |
| @par Notes |
| |
| The lock management parameter <em>max_locks_per_transaction</em>, |
| which usually is set to the default value |
| of 64, limits the number of tables that can be dropped inside a single |
| transaction (the cross-validation function). Thus, the number of different |
| values of <em>param_explored</em> (or the length of the |
| <em>explore_values</em> array) cannot be too large. For 10-fold cross |
| validation, the limit of <tt>length(<em>explore_values</em>)</tt> is around 40. If the |
| limit is exceeded, you may get an "out of shared memory" error because |
| <em>max_locks_per_transaction</em> is exceeded. |
| |
| One way to overcome this limitation is to run the cross-validation function |
| multiple times, with each run covering a different region of values of the |
| parameter. |
| |
| Note that MADlib implements cross-validation functions within certain |
| individual modules, where it is possible to optimize the calculation |
| to avoid dropping tables and prevent exceeding the \e |
| max_locks_per_transaction limitation. Since module-specific cross-validation |
| functions depend upon the implementation details of the modules to perform the |
| optimization, they will not be as flexible as the generalized cross-validation |
| function provided here. |
| |
| @anchor background |
| @par Technical Background |
| |
| One round of cross validation involves partitioning a sample of data into |
| complementary subsets, performing the analysis on one subset (called the |
| training set), and validating the analysis on the other subset (called the |
| validation set or test set). To reduce variability, multiple rounds of |
| cross validation are performed using different partitions, and the validation |
| results are averaged over the rounds. |
| |
| In k-fold cross validation, the original sample is randomly partitioned into k |
| equal sized subsamples. Of the k subsamples, a single subsample is retained as |
| the validation data for testing the model, and the remaining k−1 subsamples |
| are used as training data. The cross-validation process is repeated k |
| times (the folds), with each of the k subsamples used exactly once as the |
| validation data. The k results from the folds can be averaged (or |
| otherwise combined) to produce a single estimation. The advantage of this |
| method over repeated random sub-sampling is that all observations are used for |
| both training and validation, and each observation is used for validation |
| exactly once. 10-fold cross validation is commonly used, but in general k |
| remains an unfixed parameter. |
| |
| @anchor related |
| @par Related Topics |
| |
| File cross_validation.sql_in documenting the SQL functions. |
| |
| */ |
| |
| ------------------------------------------------------------------------ |
| /* |
| * @brief Perform cross validation for modules that conform to a fixed SQL API |
| * |
| * Note: This function is subject to a lock number limitation. In exchange it is flexible to use, so that the user can |
| * try the CV method on their own functions. The cross_validation function, on the other hand, does not have the |
| * lock number limitation. |
| * |
| * The function body is a single m4 macro call naming validation, cross_validation and cross_validation_general. |
| * The macro is defined outside this file. NOTE(review): presumably it dispatches to the Python function |
| * cross_validation_general in the cross_validation module of the validation package - confirm. |
| * |
| * @param modelling_func Name of function that trains the model |
| * @param modelling_params Array of parameters for modelling function |
| * @param modelling_params_type Types of each of the parameters for modelling function |
| * @param param_explored Name of parameter that will be checked to find the optimum value, the same name must also appear in the array of modelling_params |
| * @param explore_values Values of this parameter that will be studied |
| * @param predict_func Name of function for prediction |
| * @param predict_params Array of parameters for prediction function |
| * @param predict_params_type Types of each of the parameters for prediction function |
| * @param metric_func Name of function for measuring errors |
| * @param metric_params Array of parameters for error metric function |
| * @param metric_params_type Types of each of the parameters for metric function |
| * @param data_tbl Data table which will be split into training and validation parts |
| * @param data_id Name of the unique ID associated with each row. Provide <em>NULL</em> if there is no such column in the data table |
| * @param id_is_random Whether the provided ID is randomly assigned to each row |
| * @param validation_result Table name to store the output of CV function, see the Output for format. It will be automatically created by CV function |
| * @param data_cols Names of data columns that are going to be used. It is only useful when <em>data_id</em> is NULL, otherwise it is ignored. |
| * @param n_folds Value of k, the number of folds. Each validation uses 1/n_folds fraction of the data for validation. The 16-argument overload of this function passes 10. |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general( |
| modelling_func VARCHAR, -- function for setting up the model |
| modelling_params VARCHAR[], -- parameters for modelling |
| modelling_params_type VARCHAR[], -- parameter types for modelling |
| -- |
| param_explored VARCHAR, -- which parameter will be studied using validation |
| explore_values VARCHAR[], -- values that will be explored for this parameter |
| -- |
| predict_func VARCHAR, -- function for predicting using the model |
| predict_params VARCHAR[], -- parameters for prediction |
| predict_params_type VARCHAR[], -- parameter types for prediction |
| -- |
| metric_func VARCHAR, -- function that computes the error metric |
| metric_params VARCHAR[], -- parameters for metric |
| metric_params_type VARCHAR[], -- parameter types for metric |
| -- |
| data_tbl VARCHAR, -- table containing the data, which will be split into training and validation parts |
| data_id VARCHAR, -- name of the user-provided unique ID column, or NULL if there is none |
| id_is_random BOOLEAN, -- whether the ID provided by the user is randomly assigned |
| -- |
| validation_result VARCHAR, -- output table storing the result: param values, error, +/- |
| -- |
| data_cols VARCHAR[], -- names of data columns that are going to be used |
| n_folds INTEGER -- number of folds (k); the 16-argument overload passes 10 |
| ) RETURNS VOID AS $$ |
| PythonFunction(validation, cross_validation, cross_validation_general) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| ------------------------------------------------------------------------ |
| |
| /* |
| * @brief Convenience overload of cross_validation_general without a fold count |
| * |
| * Forwards its sixteen arguments unchanged to the 17-argument form and |
| * supplies 10 as the number of folds. |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general( |
| modelling_func VARCHAR, -- name of the function that trains the model |
| modelling_params VARCHAR[], -- arguments handed to the modelling function |
| modelling_params_type VARCHAR[], -- data types of the modelling arguments |
| -- |
| param_explored VARCHAR, -- the parameter studied by the validation |
| explore_values VARCHAR[], -- candidate values tried for that parameter |
| -- |
| predict_func VARCHAR, -- name of the function that predicts from the model |
| predict_params VARCHAR[], -- arguments handed to the prediction function |
| predict_params_type VARCHAR[], -- data types of the prediction arguments |
| -- |
| metric_func VARCHAR, -- name of the function that computes the error metric |
| metric_params VARCHAR[], -- arguments handed to the metric function |
| metric_params_type VARCHAR[], -- data types of the metric arguments |
| -- |
| data_tbl VARCHAR, -- data table to split into training and validation parts |
| data_id VARCHAR, -- unique ID column provided by the user |
| id_is_random BOOLEAN, -- whether that ID is randomly assigned |
| -- |
| validation_result VARCHAR, -- result table: param values, error, +/- |
| -- |
| data_cols VARCHAR[] -- names of the data columns to use |
| ) RETURNS VOID AS $$ |
| BEGIN |
| -- Delegate to the full signature; the trailing literal is the fold count. |
| PERFORM MADLIB_SCHEMA.cross_validation_general( |
| modelling_func, modelling_params, modelling_params_type, |
| param_explored, explore_values, |
| predict_func, predict_params, predict_params_type, |
| metric_func, metric_params, metric_params_type, |
| data_tbl, data_id, id_is_random, |
| validation_result, |
| data_cols, |
| 10 |
| ); |
| END; |
| $$ LANGUAGE plpgsql VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| ------------------------------------------------------------------------ |
| |
| |
| /** |
| * @brief A wrapper for linear regression |
| * |
| * The function body is a single m4 macro call naming validation, |
| * cross_validation and cv_linregr_train; the macro is defined outside this |
| * file. NOTE(review): presumably it dispatches to the Python function |
| * cv_linregr_train in the cross_validation module - confirm. |
| * |
| * @param tbl_source Presumably the name of the training data table (TODO confirm) |
| * @param col_ind_var Presumably the independent variable expression (TODO confirm) |
| * @param col_dep_var Presumably the dependent variable column (TODO confirm) |
| * @param tbl_result Presumably the name of the table receiving the model (TODO confirm) |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_train( |
| tbl_source VARCHAR, |
| col_ind_var VARCHAR, |
| col_dep_var VARCHAR, |
| tbl_result VARCHAR |
| ) RETURNS VOID AS $$ |
| PythonFunction(validation, cross_validation, cv_linregr_train) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| /** |
| * @brief A wrapper for linear regression prediction |
| * |
| * The function body is a single m4 macro call naming validation, |
| * cross_validation and cv_linregr_predict; the macro is defined outside this |
| * file. NOTE(review): presumably it dispatches to the Python function |
| * cv_linregr_predict in the cross_validation module - confirm. |
| * |
| * @param tbl_model Presumably the name of the table holding the trained model (TODO confirm) |
| * @param tbl_newdata Presumably the name of the table with the data to predict on (TODO confirm) |
| * @param col_ind_var Presumably the independent variable expression (TODO confirm) |
| * @param col_id ID column |
| * @param tbl_predict Presumably the name of the table receiving the predictions (TODO confirm) |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_predict( |
| tbl_model VARCHAR, |
| tbl_newdata VARCHAR, |
| col_ind_var VARCHAR, |
| col_id VARCHAR, -- ID column |
| tbl_predict VARCHAR |
| ) RETURNS VOID AS $$ |
| PythonFunction(validation, cross_validation, cv_linregr_predict) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| /* |
| * @brief Compare the prediction and actual values using the mean squared error |
| * |
| * Joins tbl_prediction to tbl_actual on the id_actual column and creates the |
| * table tbl_error holding a single column, mean_squared_error, which is the |
| * average of (prediction - values_actual)^2 over the matched rows. |
| * |
| * @param tbl_prediction Table of predicted values; it must contain a column named prediction and the id_actual column |
| * @param tbl_actual Table holding the actual values |
| * @param id_actual Name of the ID column, which must exist in both tables |
| * @param values_actual Expression (for example a column name) that yields the actual value |
| * @param tbl_error Name of the table to create for the result |
| * |
| * Note: every argument is spliced into dynamic SQL without quoting, so that |
| * schema-qualified names and expressions keep working. Pass trusted values only. |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.mse_error( |
| tbl_prediction VARCHAR, -- predicted values |
| tbl_actual VARCHAR, -- actual values |
| id_actual VARCHAR, -- ID column shared by both tables |
| values_actual VARCHAR, -- expression giving the actual value |
| tbl_error VARCHAR -- output table |
| ) RETURNS VOID AS $$ |
| DECLARE |
| old_messages VARCHAR; |
| BEGIN |
| -- Remember the current message level so that it can be restored afterwards. |
| old_messages := current_setting('client_min_messages'); |
| -- Presumably this silences NOTICE output of the CREATE TABLE below. |
| EXECUTE 'SET client_min_messages TO warning'; |
| |
| EXECUTE ' |
| CREATE TABLE '|| tbl_error ||' AS |
| SELECT |
| avg(('|| tbl_prediction ||'.prediction - ('|| values_actual ||'))^2) AS mean_squared_error |
| FROM |
| '|| tbl_prediction ||' |
| INNER JOIN |
| '|| tbl_actual ||' |
| ON |
| '|| tbl_prediction ||'.'|| id_actual ||' = '|| tbl_actual ||'.'|| id_actual; |
| |
| EXECUTE 'SET client_min_messages TO ' || old_messages; |
| END; |
| $$ LANGUAGE plpgsql VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| /* |
| * @brief Compare the prediction and actual values using the misclassification rate |
| * |
| * Joins tbl_prediction to tbl_actual on the id_actual column and creates the |
| * table tbl_error holding a single column, error_rate. Each matched row |
| * counts as 0 when prediction equals values_actual and as 1 otherwise (a |
| * comparison that yields NULL therefore counts as 1); error_rate is the |
| * average of these values. |
| * |
| * @param tbl_prediction Table of predicted values; it must contain a column named prediction and the id_actual column |
| * @param tbl_actual Table holding the actual values |
| * @param id_actual Name of the ID column, which must exist in both tables |
| * @param values_actual Expression (for example a column name) that yields the actual value |
| * @param tbl_error Name of the table to create for the result |
| * |
| * Note: every argument is spliced into dynamic SQL without quoting, so that |
| * schema-qualified names and expressions keep working. Pass trusted values only. |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.misclassification_avg( |
| tbl_prediction VARCHAR, -- predicted values |
| tbl_actual VARCHAR, -- actual values |
| id_actual VARCHAR, -- ID column shared by both tables |
| values_actual VARCHAR, -- expression giving the actual value |
| tbl_error VARCHAR -- output table |
| ) RETURNS VOID AS $$ |
| DECLARE |
| old_messages VARCHAR; |
| BEGIN |
| -- Remember the current message level so that it can be restored afterwards. |
| old_messages := current_setting('client_min_messages'); |
| -- Presumably this silences NOTICE output of the CREATE TABLE below. |
| EXECUTE 'SET client_min_messages TO warning'; |
| |
| EXECUTE ' |
| CREATE TABLE '|| tbl_error ||' AS |
| SELECT |
| avg(CASE WHEN '|| tbl_prediction ||'.prediction = ('|| values_actual ||') |
| THEN 0. ELSE 1. END) AS error_rate |
| FROM |
| '|| tbl_prediction ||' |
| INNER JOIN |
| '|| tbl_actual ||' |
| ON |
| '|| tbl_prediction ||'.'|| id_actual ||' = '|| tbl_actual ||'.'|| id_actual; |
| |
| EXECUTE 'SET client_min_messages TO ' || old_messages; |
| END; |
| $$ LANGUAGE plpgsql VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| ------------------------------------------------------------------------ |
| |
| /** |
| * @brief A prediction function for logistic regression |
| * The result is stored in the table of tbl_predict |
| * |
| * This function can be used together with cross-validation |
| * |
| * The function body is a single m4 macro call naming validation, |
| * cross_validation and cv_logregr_predict; the macro is defined outside this |
| * file. NOTE(review): presumably it dispatches to the Python function |
| * cv_logregr_predict in the cross_validation module - confirm. |
| * |
| * The parameter meanings below follow the logistic regression example in |
| * the module documentation at the top of this file. |
| * |
| * @param tbl_model Name of the table holding the trained model |
| * @param tbl_newdata Name of the table with the data to predict on |
| * @param col_ind_var Independent variable expression, for example an ARRAY[...] of columns |
| * @param col_id Name of the ID column in the data table |
| * @param tbl_predict Name of the table in which the predictions are stored |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_predict( |
| tbl_model VARCHAR, |
| tbl_newdata VARCHAR, |
| col_ind_var VARCHAR, |
| col_id VARCHAR, |
| tbl_predict VARCHAR |
| ) RETURNS VOID AS $$ |
| PythonFunction(validation, cross_validation, cv_logregr_predict) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |
| |
| /** |
| * @brief Per-row correctness indicator for logistic regression |
| * |
| * Runs logregr_predict on the coefficients and the independent variables and |
| * compares the outcome with the observed dependent variable. |
| * |
| * @param coef Logistic fitting coefficients. Note: MADlib logregr_train |
| * function, unlike elastic_net, does not produce a separate intercept term. |
| * @param col_ind Independent variable, an array |
| * @param col_dep Dependent variable |
| * |
| * @return 1 if the prediction equals col_dep, 0 if it differs. The result is |
| * NULL when the comparison itself is NULL (a NULL prediction or a NULL col_dep). |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_accuracy( |
| coef DOUBLE PRECISION[], |
| col_ind DOUBLE PRECISION[], |
| col_dep BOOLEAN |
| ) RETURNS INTEGER AS $$ |
| DECLARE |
| -- Predicted class, computed once on entry to the block. |
| predicted BOOLEAN := MADLIB_SCHEMA.logregr_predict(coef, col_ind); |
| BEGIN |
| -- The boolean match flag is cast to an integer: true gives 1, false gives 0. |
| RETURN CAST(predicted = col_dep AS INTEGER); |
| END; |
| $$ LANGUAGE plpgsql |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `'); |
| |
| /** |
| * @brief Metric function for logistic regression |
| * |
| * It computes the percentage of correct predictions. |
| * The result is stored in the table of tbl_accuracy |
| * |
| * The function body is a single m4 macro call naming validation, |
| * cross_validation and cv_logregr_accuracy; the macro is defined outside this |
| * file. NOTE(review): presumably it dispatches to the Python function |
| * cv_logregr_accuracy in the cross_validation module - confirm. |
| * |
| * @param tbl_predict Presumably the name of the table holding the predictions (TODO confirm) |
| * @param tbl_source Presumably the name of the table holding the actual values (TODO confirm) |
| * @param col_id Presumably the name of the ID column used to match rows (TODO confirm) |
| * @param col_dep_var Presumably the dependent variable column (TODO confirm) |
| * @param tbl_accuracy Name of the table in which the result is stored |
| */ |
| CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_accuracy( |
| tbl_predict VARCHAR, |
| tbl_source VARCHAR, |
| col_id VARCHAR, |
| col_dep_var VARCHAR, |
| tbl_accuracy VARCHAR |
| ) RETURNS VOID AS $$ |
| PythonFunction(validation, cross_validation, cv_logregr_accuracy) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); |