blob: 77b2c2b4a57b9f1906a8dda95901ceb9112b91a3 [file] [log] [blame]
/* ----------------------------------------------------------------------- *//**
*
* @file cross_validation.sql_in
*
* @brief SQL functions for cross validation
* @date January 2011
*
* @sa For a brief introduction to the usage of cross validation, see the
* module description \ref grp_validation.
*
*//* ----------------------------------------------------------------------- */
m4_include(`SQLCommon.m4')
/**
@addtogroup grp_validation
@brief Estimates the fit of a predictive model given a data set and specifications for the training, prediction, and error estimation functions.
<div class="toc"><b>Contents</b>
<ul>
<li><a href="#cvfunction">Cross-Validation Function</a></li>
<li><a href="#examples">Examples</a></li>
<li><a href="#notes">Notes</a></li>
<li><a href="#background">Technical Background</a></li>
<li><a href="#related">Related Topics</a></li>
</ul>
</div>
Estimates the fit of a predictive model given a data set and specifications for
the training, prediction, and error estimation functions.
Cross validation, sometimes called rotation estimation, is a technique for
assessing how the results of a statistical analysis will generalize to an
independent data set. It is mainly used in settings where the goal is
prediction, and you want to estimate how accurately a predictive model will
perform in practice.
The cross-validation function provided by this module is very flexible and
can work with algorithms you want to cross validate, including algorithms you
write yourself. Among the inputs to the cross-validation function are
specifications of the modelling, prediction, and error metric functions. These
three-part specifications include the name of the function, an array of
arguments to pass to the function, and an array of the data types of the
arguments. This makes it possible to use functions from other MADlib modules
or user-defined functions that you supply.
- The modelling (training) function takes in a given data set with independent
and dependent variables and produces a model, which is stored in an output
table.
- The prediction function takes in the model generated by the modelling function
and a different data set with independent variables, and produces a prediction
of the dependent variables based on the model, which is stored in an
output table. The prediction function should take a unique ID column name in
the data table as one of the inputs, so that the prediction result can be
compared with the validation values.
Note: Prediction functions in some MADlib modules do not save results into an output
table. These prediction functions are not suitable for this cross-validation module.
- The error metric function compares the prediction results with the known
values of the dependent variables in the data set that was fed into the
prediction function. It computes the error metric using the specified error
metric function, and stores the results in a table.
Other inputs include the output table name, k value for the k-fold
cross validation, and how many folds to try. For example, you can choose to run a
simple validation instead of a full cross validation.
@anchor cvfunction
@par Cross-Validation Function
<pre class="syntax">
cross_validation_general( modelling_func,
modelling_params,
modelling_params_type,
param_explored,
explore_values,
predict_func,
predict_params,
predict_params_type,
metric_func,
metric_params,
metric_params_type,
data_tbl,
data_id,
id_is_random,
validation_result,
data_cols,
fold_num
)</pre>
\b Arguments
<dl class="arglist">
<dt>modelling_func</dt>
<dd>VARCHAR. The name of the function that trains the model.</dd>
<dt>modelling_params</dt>
<dd>VARCHAR[]. An array of parameters to supply to the modelling function.</dd>
<dt>modelling_params_type</dt>
<dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the modelling function.</dd>
<dt>param_explored</dt>
<dd>VARCHAR. The name of the parameter that will be checked to find the optimum value. The name must appear in the \e modelling_params array.</dd>
<dt>explore_values</dt>
<dd>VARCHAR[]. An array of the values of the \e param_explored parameter that are to be studied.</dd>
<dt>predict_func</dt>
<dd>VARCHAR. The name of the prediction function.</dd>
<dt>predict_params</dt>
<dd>VARCHAR[]. An array of parameters to supply to the prediction function.</dd>
<dt>predict_params_type</dt>
<dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the prediction function.</dd>
<dt>metric_func</dt>
<dd>VARCHAR. The name of the function for measuring errors.</dd>
<dt>metric_params</dt>
<dd>VARCHAR[]. An array of parameters to supply to the error metric function.</dd>
<dt>metric_params_type</dt>
<dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the metric function.</dd>
<dt>data_tbl</dt>
<dd>VARCHAR. The name of the data table that will be split into training and validation parts.</dd>
<dt>data_id</dt>
<dd>VARCHAR. The name of the column containing a unique ID associated with
each row, or NULL if the table has no such column.
Ideally, the data set has a unique ID for each row so that it is easier to
partition the data set into the training part and the validation part. Set the
\e id_is_random argument to inform the cross-validation function whether
the ID value is randomly assigned to each row. If it is not randomly
assigned, the cross-validation function generates a random ID for each row.
</dd>
<dt>id_is_random</dt>
<dd>BOOLEAN. TRUE if the provided ID is randomly assigned to each row.</dd>
<dt>validation_result</dt>
<dd>VARCHAR. The name of the table to store the output of the cross-validation function. The output table has the following columns:
<table class="output">
<tr>
<th>param_explored</th>
<td>The name of the parameter checked to find the optimum value. This is the
same name specified in the \e param_explored argument of the \e cross_validation_general() function.</td>
</tr>
<tr>
<th>average error</th>
<td>The average of the errors computed by the error metric function.</td>
</tr>
<tr>
<th>standard deviation of error</th>
<td>The standard deviation of the errors.</td>
</tr>
</table>
</dd>
<dt>data_cols</dt>
<dd>VARCHAR[]. An array of names of data columns to use in the calculation.
When its value is NULL, the function will automatically figure out all the column names of the data table.
This is only used if the <em>data_id</em> argument is NULL, otherwise it is
ignored.
If the data set has no unique ID for each row, the cross-validation function
copies the data set to a temporary table with a randomly assigned ID column.
Setting this argument to the list of independent and dependent
variables that are to be used in the calculation minimizes the copying
workload by only copying the required data.
The new temporary table is dropped after the computation has finished.
</dd>
<dt><em>fold_num</em></dt>
<dd>INTEGER, default: 10. Value of k, i.e. the number of folds. Each fold uses a 1/fold_num
fraction of the data for validation.</dd>
</dl>
The parameter arrays for the modelling, prediction and metric functions can include the following special keywords:
- <em>\%data%</em> &ndash; The argument position for training/validation data
- <em>\%model%</em> &ndash; The argument position for the output/input of modelling/prediction function
- <em>\%id%</em> &ndash; The argument position of the unique ID column (user-provided or generated by the cross-validation function, as described above)
- <em>\%prediction%</em> &ndash; The argument position for the output/input of prediction/metric function
- <em>\%error%</em> &ndash; The argument position for the output of the metric function
<b>Note</b>: If the argument <em>explore_values</em> is NULL or has zero length, then the cross-validation function will only run a data folding.
@anchor examples
@examp
-# Load some sample data:
<pre class="example">
DROP TABLE IF EXISTS houses;
CREATE TABLE houses ( id INT,
tax INT,
bedroom INT,
bath FLOAT,
size INT,
lot INT,
zipcode INT,
price INT,
high_priced BOOLEAN
);
INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode, high_priced) VALUES
(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301, 'f'::boolean),
(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301, 'f'::boolean),
(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301, 'f'::boolean),
(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301, 'f'::boolean),
(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301, 't'::boolean),
(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301, 'f'::boolean),
(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301, 't'::boolean),
(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301, 't'::boolean),
(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301, 't'::boolean),
(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301, 't'::boolean),
(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301, 'f'::boolean),
(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301, 't'::boolean),
(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301, 't'::boolean),
(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301, 't'::boolean),
(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301, 'f'::boolean),
(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010, 'f'::boolean),
(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010, 't'::boolean),
(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010, 'f'::boolean),
(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean),
(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),
(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),
(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),
(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),
(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),
(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),
(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),
(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean),
(28 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean),
(29 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),
(30 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),
(31 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),
(32 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),
(33 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),
(34 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),
(35 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),
(36 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean);
</pre>
-# Use the general function to explore lambda values
for elastic net. (Note that elastic net also has a
built in cross validation function
for selecting elastic net control parameter alpha and
regularization value lambda.)
<pre class="example">
DROP TABLE IF EXISTS houses_cv_results;
SELECT madlib.cross_validation_general(
-- modelling_func
'madlib.elastic_net_train',
-- modelling_params
'{%%data%, %%model%, price, "array[tax, bath, size]", gaussian, 0.5, lambda, TRUE, NULL, fista,
"{eta = 2, max_stepsize = 2, use_active_set = t}",
NULL, 10000, 1e-6}'::varchar[],
-- modelling_params_type
'{varchar, varchar, varchar, varchar, varchar, double precision,
double precision, boolean, varchar, varchar, varchar, varchar,
integer, double precision}'::varchar[],
-- param_explored
'lambda',
-- explore_values
'{0.1, 0.2}'::varchar[],
-- predict_func
'madlib.elastic_net_predict',
-- predict_params
'{%%model%, %%data%, %%id%, %%prediction%}'::varchar[],
-- predict_params_type
'{text, text, text, text}'::varchar[],
-- metric_func
'madlib.mse_error',
-- metric_params
'{%%prediction%, %%data%, %%id%, price, %%error%}'::varchar[],
-- metric_params_type
'{varchar, varchar, varchar, varchar, varchar}'::varchar[],
-- data_tbl
'houses',
-- data_id
'id',
-- id_is_random
FALSE,
-- validation_result
'houses_cv_results',
-- data_cols
NULL,
-- fold_num
3
);
SELECT * FROM houses_cv_results;
</pre>
Results from the lambda values explored:
<pre class="result">
lambda | mean_squared_error_avg | mean_squared_error_stddev
--------+------------------------+---------------------------
0.1 | 1194685622.1604 | 366687470.779826
0.2 | 1181768409.98238 | 352203200.758414
(2 rows)
</pre>
-# Here we use the general function to explore
maximum number of iterations for logistic regression:
<pre class="example">
DROP TABLE IF EXISTS houses_logregr_cv;
SELECT madlib.cross_validation_general(
-- modelling_func
'madlib.logregr_train',
-- modelling_params
'{%%data%, %%model%, high_priced, "ARRAY[1, bedroom, bath, size]", NULL, max_iter}'::varchar[],
-- modelling_params_type
'{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[],
-- param_explored
'max_iter',
-- explore_values
'{2, 10, 40, 100}'::varchar[],
-- predict_func
'madlib.cv_logregr_predict',
-- predict_params
'{%%model%, %%data%, "ARRAY[1, bedroom, bath, size]", id, %%prediction%}'::varchar[],
-- predict_params_type
'{varchar, varchar,varchar,varchar,varchar}'::varchar[],
-- metric_func
'madlib.misclassification_avg',
-- metric_params
'{%%prediction%, %%data%, id, high_priced, %%error%}'::varchar[],
-- metric_params_type
'{varchar, varchar, varchar, varchar, varchar}'::varchar[],
-- data_tbl
'houses',
-- data_id
'id',
-- id_is_random
FALSE,
-- validation_result
'houses_logregr_cv',
-- data_cols
NULL,
-- fold_num
5
);
SELECT * FROM houses_logregr_cv;
</pre>
Results from the explored number of iterations:
<pre class="result">
max_iter | error_rate_avg | error_rate_stddev
----------+------------------------+--------------------------------------------
2 | 0.28214285714285714286 | 0.2053183193114972855562362870460638951565
10 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
40 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
100 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
(4 rows)
</pre>
@anchor notes
@par Notes
The lock management parameter <em>max_locks_per_transaction</em>,
which usually is set to the default value
of 64, limits the number of tables that can be dropped inside a single
transaction (the cross-validation function). Thus, the number of different
values of <em>param_explored</em> (or the length of the
<em>explore_values</em> array) cannot be too large. For 10-fold cross
validation, the limit of <tt>length(<em>explore_values</em>)</tt> is around 40. If the
limit is exceeded, you may get an "out of shared memory" error because
<em>max_locks_per_transaction</em> is exceeded.
One way to overcome this limitation is to run the cross-validation function
multiple times, with each run covering a different region of values of the
parameter.
Note that MADlib implements cross-validation functions within certain
individual modules, where it is possible to optimize the calculation
to avoid dropping tables and prevent exceeding the \e
max_locks_per_transaction limitation. Since module-specific cross-validation
functions depend upon the implementation details of the modules to perform the
optimization, they will not be as flexible as the generalized cross-validation
function provided here.
@anchor background
@par Technical Background
One round of cross validation involves partitioning a sample of data into
complementary subsets, performing the analysis on one subset (called the
training set), and validating the analysis on the other subset (called the
validation set or test set). To reduce variability, multiple rounds of
cross validation are performed using different partitions, and the validation
results are averaged over the rounds.
In k-fold cross validation, the original sample is randomly partitioned into k
equal sized subsamples. Of the k subsamples, a single subsample is retained as
the validation data for testing the model, and the remaining k&minus;1 subsamples
are used as training data. The cross-validation process is repeated k
times (the folds), with each of the k subsamples used exactly once as the
validation data. The k results from the folds can be averaged (or
otherwise combined) to produce a single estimation. The advantage of this
method over repeated random sub-sampling is that all observations are used for
both training and validation, and each observation is used for validation
exactly once. 10-fold cross validation is commonly used, but in general k
remains an unfixed parameter.
@anchor related
@par Related Topics
File cross_validation.sql_in documenting the SQL functions.
*/
------------------------------------------------------------------------
/*
 * @brief Perform cross validation for modules that conform to a fixed SQL API
 *
 * Note: This function is subject to a lock number limitation. In exchange it
 * is flexible, so that users can try the CV method on their own functions.
 * The cross_validation function, on the other hand, does not have the lock
 * number limitation.
 *
 * @param modelling_func        Name of the function that trains the model
 * @param modelling_params      Array of parameters for the modelling function
 * @param modelling_params_type Type of each parameter of the modelling function
 * @param param_explored        Name of the parameter that will be checked to find
 *                              the optimum value; the same name must also appear
 *                              in the modelling_params array
 * @param explore_values        Values of this parameter that will be studied
 * @param predict_func          Name of the function for prediction
 * @param predict_params        Array of parameters for the prediction function
 * @param predict_params_type   Type of each parameter of the prediction function
 * @param metric_func           Name of the function for measuring errors
 * @param metric_params         Array of parameters for the error metric function
 * @param metric_params_type    Type of each parameter of the metric function
 * @param data_tbl              Data table which will be split into training and
 *                              validation parts
 * @param data_id               Name of the unique ID associated with each row.
 *                              Provide NULL if the data table has no such column
 * @param id_is_random          Whether the provided ID is randomly assigned to
 *                              each row
 * @param validation_result     Name of the table that stores the output of the CV
 *                              function. It is created automatically by the CV
 *                              function
 * @param data_cols             Names of the data columns that are going to be
 *                              used. Only useful when data_id is NULL, otherwise
 *                              it is ignored
 * @param n_folds               Value of k, i.e. the number of folds. Each
 *                              validation uses a 1/n_folds fraction of the data
 *                              for validation. The overload without this argument
 *                              passes 10
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    -- model training
    modelling_func          VARCHAR,    -- name of the training function
    modelling_params        VARCHAR[],  -- arguments passed to the training function
    modelling_params_type   VARCHAR[],  -- type name of each training argument
    -- parameter search
    param_explored          VARCHAR,    -- parameter studied by the validation
    explore_values          VARCHAR[],  -- candidate values for that parameter
    -- prediction
    predict_func            VARCHAR,    -- name of the prediction function
    predict_params          VARCHAR[],  -- arguments passed to the prediction function
    predict_params_type     VARCHAR[],  -- type name of each prediction argument
    -- error metric
    metric_func             VARCHAR,    -- name of the error metric function
    metric_params           VARCHAR[],  -- arguments passed to the metric function
    metric_params_type      VARCHAR[],  -- type name of each metric argument
    -- input data
    data_tbl                VARCHAR,    -- table that is split into training and validation parts
    data_id                 VARCHAR,    -- unique row ID column, or NULL if there is none
    id_is_random            BOOLEAN,    -- TRUE if the given ID is randomly assigned
    -- output
    validation_result       VARCHAR,    -- result table: param values, error, +/-
    -- tuning
    data_cols               VARCHAR[],  -- data columns to use (only when data_id is NULL)
    n_folds                 INTEGER     -- number of folds (k)
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cross_validation_general)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
------------------------------------------------------------------------
/*
 * @brief Overload of cross_validation_general that runs with 10 folds
 *
 * Accepts the same first sixteen arguments as the seventeen-argument version
 * and forwards them unchanged, supplying 10 as the number of folds.
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    -- model training
    modelling_func          VARCHAR,    -- name of the training function
    modelling_params        VARCHAR[],  -- arguments passed to the training function
    modelling_params_type   VARCHAR[],  -- type name of each training argument
    -- parameter search
    param_explored          VARCHAR,    -- parameter studied by the validation
    explore_values          VARCHAR[],  -- candidate values for that parameter
    -- prediction
    predict_func            VARCHAR,    -- name of the prediction function
    predict_params          VARCHAR[],  -- arguments passed to the prediction function
    predict_params_type     VARCHAR[],  -- type name of each prediction argument
    -- error metric
    metric_func             VARCHAR,    -- name of the error metric function
    metric_params           VARCHAR[],  -- arguments passed to the metric function
    metric_params_type      VARCHAR[],  -- type name of each metric argument
    -- input data
    data_tbl                VARCHAR,    -- table that is split into training and validation parts
    data_id                 VARCHAR,    -- unique row ID column, or NULL if there is none
    id_is_random            BOOLEAN,    -- TRUE if the given ID is randomly assigned
    -- output
    validation_result       VARCHAR,    -- result table: param values, error, +/-
    -- tuning
    data_cols               VARCHAR[]   -- data columns to use (only when data_id is NULL)
) RETURNS VOID AS $$
BEGIN
    -- Delegate to the full version; the trailing 10 is the fold count.
    PERFORM MADLIB_SCHEMA.cross_validation_general(
        modelling_func, modelling_params, modelling_params_type,
        param_explored, explore_values,
        predict_func, predict_params, predict_params_type,
        metric_func, metric_params, metric_params_type,
        data_tbl, data_id, id_is_random,
        validation_result,
        data_cols,
        10
    );
END;
$$ LANGUAGE plpgsql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
------------------------------------------------------------------------
/**
 * @brief A wrapper for linear regression
 *
 * The function body is generated at build time; the actual logic is not
 * visible in this file. The parameter descriptions below are inferred from
 * the parameter names only and should be confirmed against the Python
 * implementation.
 *
 * @param tbl_source   Presumably the name of the table with the training data
 * @param col_ind_var  Presumably the independent variable(s) of the model
 * @param col_dep_var  Presumably the dependent variable of the model
 * @param tbl_result   Presumably the name of the table receiving the model
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_train(
    tbl_source      VARCHAR,
    col_ind_var     VARCHAR,
    col_dep_var     VARCHAR,
    tbl_result      VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_train)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
/**
 * @brief A wrapper for linear regression prediction
 *
 * The function body is generated at build time; the actual logic is not
 * visible in this file. Apart from col_id, the parameter descriptions below
 * are inferred from the parameter names only and should be confirmed against
 * the Python implementation.
 *
 * @param tbl_model    Presumably the name of the table holding the fitted model
 * @param tbl_newdata  Presumably the name of the table with the data to predict on
 * @param col_ind_var  Presumably the independent variable(s) of the model
 * @param col_id       ID column
 * @param tbl_predict  Presumably the name of the table receiving the predictions
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,    -- ID column
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_predict)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
/**
 * @brief Compare the prediction and actual values using the mean squared error
 *
 * Joins the prediction table with the table of actual values on the ID
 * column and stores avg((prediction - actual)^2) in a newly created table.
 *
 * @param tbl_prediction Name of the table with the predicted values. It must
 *                       contain a column named prediction and the ID column
 * @param tbl_actual     Name of the table with the actual values. It must
 *                       contain the ID column
 * @param id_actual      Name of the ID column. The same column name is used
 *                       for both tables in the join condition
 * @param values_actual  SQL expression yielding the actual value. It is placed
 *                       in the query without a table qualifier
 * @param tbl_error      Name of the table to create. It receives one row with
 *                       the single column mean_squared_error
 *
 * Note: every argument is concatenated into the query text as is, without
 * quoting or escaping, so callers must only pass trusted, valid SQL fragments.
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.mse_error(
    tbl_prediction  VARCHAR,    -- predicted values
    tbl_actual      VARCHAR,
    id_actual       VARCHAR,
    values_actual   VARCHAR,
    tbl_error       VARCHAR
) RETURNS VOID AS $$
DECLARE
    old_messages VARCHAR;
BEGIN
    -- Remember the current message level so that it can be restored below.
    old_messages := (SELECT setting FROM pg_settings WHERE name = 'client_min_messages');

    -- Only warnings and above while the table is created
    -- (presumably to hide the notices emitted by CREATE TABLE AS).
    EXECUTE 'SET client_min_messages TO warning';

    EXECUTE '
        CREATE TABLE ' || tbl_error || ' AS
        SELECT
            avg((' || tbl_prediction || '.prediction - (' || values_actual || '))^2) AS mean_squared_error
        FROM
            ' || tbl_prediction || '
            INNER JOIN ' || tbl_actual || '
            ON ' || tbl_prediction || '.' || id_actual || ' = ' || tbl_actual || '.' || id_actual;

    EXECUTE 'SET client_min_messages TO ' || old_messages;
END;
$$ LANGUAGE plpgsql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
/**
 * @brief Compare the prediction and actual values using the misclassification rate
 *
 * Joins the prediction table with the table of actual values on the ID
 * column. Each joined row counts as 0 when the prediction equals the actual
 * value and as 1 otherwise (this includes comparisons that yield NULL). The
 * average of these values is stored in a newly created table.
 *
 * @param tbl_prediction Name of the table with the predicted values. It must
 *                       contain a column named prediction and the ID column
 * @param tbl_actual     Name of the table with the actual values. It must
 *                       contain the ID column
 * @param id_actual      Name of the ID column. The same column name is used
 *                       for both tables in the join condition
 * @param values_actual  SQL expression yielding the actual value. It is placed
 *                       in the query without a table qualifier
 * @param tbl_error      Name of the table to create. It receives one row with
 *                       the single column error_rate
 *
 * Note: every argument is concatenated into the query text as is, without
 * quoting or escaping, so callers must only pass trusted, valid SQL fragments.
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.misclassification_avg(
    tbl_prediction  VARCHAR,    -- predicted values
    tbl_actual      VARCHAR,
    id_actual       VARCHAR,
    values_actual   VARCHAR,
    tbl_error       VARCHAR
) RETURNS VOID AS $$
DECLARE
    old_messages VARCHAR;
BEGIN
    -- Remember the current message level so that it can be restored below.
    old_messages := (SELECT setting FROM pg_settings WHERE name = 'client_min_messages');

    -- Only warnings and above while the table is created
    -- (presumably to hide the notices emitted by CREATE TABLE AS).
    EXECUTE 'SET client_min_messages TO warning';

    EXECUTE '
        CREATE TABLE ' || tbl_error || ' AS
        SELECT
            avg(CASE WHEN ' || tbl_prediction || '.prediction = (' || values_actual || ')
                     THEN 0. ELSE 1. END) AS error_rate
        FROM
            ' || tbl_prediction || '
            INNER JOIN ' || tbl_actual || '
            ON ' || tbl_prediction || '.' || id_actual || ' = ' || tbl_actual || '.' || id_actual;

    EXECUTE 'SET client_min_messages TO ' || old_messages;
END;
$$ LANGUAGE plpgsql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
------------------------------------------------------------------------
/**
 * @brief A prediction function for logistic regression
 *
 * The result is stored in the table named by tbl_predict. This function can
 * be used together with cross-validation; the logistic regression example in
 * the module documentation passes it as predict_func.
 *
 * The parameter descriptions follow that example call; the function body is
 * generated at build time and is not visible in this file.
 *
 * @param tbl_model    Name of the model table (the model placeholder in the example)
 * @param tbl_newdata  Name of the data table to predict on (the data placeholder
 *                     in the example)
 * @param col_ind_var  Independent variables, given as an array expression in
 *                     the example
 * @param col_id       Name of the ID column
 * @param tbl_predict  Name of the table in which the result is stored
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_predict)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
/**
 * @brief Metric function for logistic regression
 *
 * Runs logregr_predict on one row and reports whether the outcome matches
 * the observed dependent variable.
 *
 * @param coef     Logistic fitting coefficients. Note: the MADlib logregr_train
 *                 function, unlike elastic_net, does not produce a separate
 *                 intercept term.
 * @param col_ind  Independent variable, an array
 * @param col_dep  Dependent variable
 *
 * @return 1 if the prediction is the same as col_dep, otherwise 0. The result
 *         is NULL when the comparison itself is NULL, e.g. for a NULL col_dep.
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_accuracy(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[],
    col_dep     BOOLEAN
) RETURNS INTEGER AS $$
DECLARE
    -- Predicted class for this row, computed once on entry.
    predicted BOOLEAN := MADLIB_SCHEMA.logregr_predict(coef, col_ind);
BEGIN
    -- Casting the boolean comparison gives 1 for a match and 0 for a mismatch.
    RETURN CAST((predicted = col_dep) AS INTEGER);
END;
$$ LANGUAGE plpgsql
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
/**
 * @brief Metric function for logistic regression
 *
 * It computes the percentage of correct predictions. The result is stored in
 * the table named by tbl_accuracy.
 *
 * The function body is generated at build time; the actual logic is not
 * visible in this file. Apart from tbl_accuracy, the parameter descriptions
 * below are inferred from the parameter names only and should be confirmed
 * against the Python implementation.
 *
 * @param tbl_predict   Presumably the name of the table with the predictions
 * @param tbl_source    Presumably the name of the table with the observed data
 * @param col_id        Presumably the name of the ID column
 * @param col_dep_var   Presumably the dependent variable to compare against
 * @param tbl_accuracy  Name of the table in which the result is stored
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_accuracy(
    tbl_predict     VARCHAR,
    tbl_source      VARCHAR,
    col_id          VARCHAR,
    col_dep_var     VARCHAR,
    tbl_accuracy    VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_accuracy)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');