<!-- HTML header for doxygen 1.8.4-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen 1.8.13"/>
<meta name="keywords" content="madlib,postgres,greenplum,machine learning,data mining,deep learning,ensemble methods,data science,market basket analysis,affinity analysis,pca,lda,regression,elastic net,huber white,proportional hazards,k-means,latent dirichlet allocation,bayes,support vector machines,svm"/>
<title>MADlib: Cross Validation</title>
<link href="tabs.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="jquery.js"></script>
<script type="text/javascript" src="dynsections.js"></script>
<link href="navtree.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="resize.js"></script>
<script type="text/javascript" src="navtreedata.js"></script>
<script type="text/javascript" src="navtree.js"></script>
<script type="text/javascript">
$(document).ready(initResizable);
</script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<script type="text/javascript">
$(document).ready(function() { init_search(); });
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js", "TeX/AMSmath.js", "TeX/AMSsymbols.js"],
jax: ["input/TeX","output/HTML-CSS"],
});
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
<!-- hack in the navigation tree -->
<script type="text/javascript" src="eigen_navtree_hacks.js"></script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
<link href="madlib_extra.css" rel="stylesheet" type="text/css"/>
<!-- google analytics -->
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','//www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-45382226-1', 'madlib.apache.org');
ga('send', 'pageview');
</script>
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<td id="projectlogo"><a href="http://madlib.apache.org"><img alt="Logo" src="madlib.png" height="50" style="padding-left:0.5em;" border="0"/></a></td>
<td style="padding-left: 0.5em;">
<div id="projectname">
<span id="projectnumber">1.20.0</span>
</div>
<div id="projectbrief">User Documentation for Apache MADlib</div>
</td>
<td> <div id="MSearchBox" class="MSearchBoxInactive">
<span class="left">
<img id="MSearchSelect" src="search/mag_sel.png"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
alt=""/>
<input type="text" id="MSearchField" value="Search" accesskey="S"
onfocus="searchBox.OnSearchFieldFocus(true)"
onblur="searchBox.OnSearchFieldFocus(false)"
onkeyup="searchBox.OnSearchFieldChange(event)"/>
</span><span class="right">
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
</span>
</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.8.13 -->
<script type="text/javascript">
var searchBox = new SearchBox("searchBox", "search",false,'Search');
</script>
</div><!-- top -->
<div id="side-nav" class="ui-resizable side-nav-resizable">
<div id="nav-tree">
<div id="nav-tree-contents">
<div id="nav-sync" class="sync"></div>
</div>
</div>
<div id="splitbar" style="-moz-user-select:none;"
class="ui-resizable-handle">
</div>
</div>
<script type="text/javascript">
$(document).ready(function(){initNavTree('group__grp__validation.html','');});
</script>
<div id="doc-content">
<!-- window showing the filter options -->
<div id="MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
onkeydown="return searchBox.OnSearchSelectKey(event)">
</div>
<!-- iframe showing the search results (closed by default) -->
<div id="MSearchResultsWindow">
<iframe src="javascript:void(0)" frameborder="0"
name="MSearchResults" id="MSearchResults">
</iframe>
</div>
<div class="header">
<div class="headertitle">
<div class="title">Cross Validation<div class="ingroups"><a class="el" href="group__grp__mdl.html">Model Selection</a></div></div> </div>
</div><!--header-->
<div class="contents">
<div class="toc"><b>Contents</b> <ul>
<li>
<a href="#cvfunction">Cross-Validation Function</a> </li>
<li>
<a href="#examples">Examples</a> </li>
<li>
<a href="#notes">Notes</a> </li>
<li>
<a href="#background">Technical Background</a> </li>
<li>
<a href="#related">Related Topics</a> </li>
</ul>
</div><p>Estimates the fit of a predictive model given a data set and specifications for the training, prediction, and error estimation functions.</p>
<p>Cross validation, sometimes called rotation estimation, is a technique for assessing how the results of a statistical analysis will generalize to an independent data set. It is mainly used in settings where the goal is prediction, and you want to estimate how accurately a predictive model will perform in practice.</p>
<p>The cross-validation function provided by this module is flexible and can work with the algorithms you want to cross validate, including algorithms you write yourself. Among the inputs to the cross-validation function are specifications of the modelling, prediction, and error metric functions. Each of these three-part specifications consists of the name of the function, an array of arguments to pass to the function, and an array of the data types of the arguments. This makes it possible to use functions from other MADlib modules or user-defined functions that you supply.</p>
<ul>
<li>The modelling (training) function takes in a given data set with independent and dependent variables and produces a model, which is stored in an output table.</li>
<li>The prediction function takes in the model generated by the modelling function and a different data set with independent variables, and produces a prediction of the dependent variables based on the model, which is stored in an output table. The prediction function should take a unique ID column name in the data table as one of the inputs, so that the prediction result can be compared with the validation values. Note: The prediction functions in some MADlib modules do not save their results into an output table. Such prediction functions are not suitable for this cross-validation module.</li>
<li>The error metric function compares the prediction results with the known values of the dependent variables in the data set that was fed into the prediction function. It computes the error metric using the specified error metric function, and stores the results in a table.</li>
</ul>
<p>Other inputs include the output table name, k value for the k-fold cross validation, and how many folds to try. For example, you can choose to run a simple validation instead of a full cross validation.</p>
<p><a class="anchor" id="cvfunction"></a></p><dl class="section user"><dt>Cross-Validation Function</dt><dd></dd></dl>
<pre class="syntax">
cross_validation_general( modelling_func,
modelling_params,
modelling_params_type,
param_explored,
explore_values,
predict_func,
predict_params,
predict_params_type,
metric_func,
metric_params,
metric_params_type,
data_tbl,
data_id,
id_is_random,
validation_result,
data_cols,
fold_num
)</pre><p> <b>Arguments</b> </p><dl class="arglist">
<dt>modelling_func </dt>
<dd><p class="startdd">VARCHAR. The name of the function that trains the model.</p>
<p class="enddd"></p>
</dd>
<dt>modelling_params </dt>
<dd><p class="startdd">VARCHAR[]. An array of parameters to supply to the modelling function.</p>
<p class="enddd"></p>
</dd>
<dt>modelling_params_type </dt>
<dd><p class="startdd">VARCHAR[]. An array of data type names for each of the parameters supplied to the modelling function.</p>
<p class="enddd"></p>
</dd>
<dt>param_explored </dt>
<dd><p class="startdd">VARCHAR. The name of the parameter that will be checked to find the optimum value. The name must appear in the <em>modelling_params</em> array.</p>
<p class="enddd"></p>
</dd>
<dt>explore_values </dt>
<dd><p class="startdd">VARCHAR[]. An array of the values of the <em>param_explored</em> parameter that are to be studied.</p>
<p class="enddd"></p>
</dd>
<dt>predict_func </dt>
<dd><p class="startdd">VARCHAR. The name of the prediction function.</p>
<p class="enddd"></p>
</dd>
<dt>predict_params </dt>
<dd><p class="startdd">VARCHAR[]. An array of parameters to supply to the prediction function.</p>
<p class="enddd"></p>
</dd>
<dt>predict_params_type </dt>
<dd><p class="startdd">VARCHAR[]. An array of data type names for each of the parameters supplied to the prediction function.</p>
<p class="enddd"></p>
</dd>
<dt>metric_func </dt>
<dd><p class="startdd">VARCHAR. The name of the function for measuring errors.</p>
<p class="enddd"></p>
</dd>
<dt>metric_params </dt>
<dd><p class="startdd">VARCHAR[]. An array of parameters to supply to the error metric function.</p>
<p class="enddd"></p>
</dd>
<dt>metric_params_type </dt>
<dd><p class="startdd">VARCHAR[]. An array of data type names for each of the parameters supplied to the metric function.</p>
<p class="enddd"></p>
</dd>
<dt>data_tbl </dt>
<dd><p class="startdd">VARCHAR. The name of the data table that will be split into training and validation parts.</p>
<p class="enddd"></p>
</dd>
<dt>data_id </dt>
<dd><p class="startdd">VARCHAR. The name of the column containing a unique ID associated with each row, or NULL if the table has no such column.</p>
<p>Ideally, the data set has a unique ID for each row so that it is easier to partition the data set into the training part and the validation part. Set the <em>id_is_random</em> argument to inform the cross-validation function whether the ID value is randomly assigned to each row. If it is not randomly assigned, the cross-validation function generates a random ID for each row. </p>
<p class="enddd"></p>
</dd>
<dt>id_is_random </dt>
<dd><p class="startdd">BOOLEAN. TRUE if the provided ID is randomly assigned to each row.</p>
<p class="enddd"></p>
</dd>
<dt>validation_result </dt>
<dd><p class="startdd">VARCHAR. The name of the table to store the output of the cross-validation function. The output table has the following columns: </p><table class="output">
<tr>
<th>param_explored </th><td>The name of the parameter checked to find the optimum value. This is the same name specified in the <em>param_explored</em> argument of the <em><a class="el" href="cross__validation_8sql__in.html#a2a7791b05f51e8748ab7b6ccf328a7e2">cross_validation_general()</a></em> function. </td></tr>
<tr>
<th>average error </th><td>The average of the errors computed by the error metric function. </td></tr>
<tr>
<th>standard deviation of error </th><td>The standard deviation of the errors. </td></tr>
</table>
<p class="enddd"></p>
</dd>
<dt>data_cols </dt>
<dd><p class="startdd">A comma-separated list of names of data columns to use in the calculation. When its value is NULL, the function will automatically figure out all the column names of the data table. This is only used if the <em>data_id</em> argument is NULL, otherwise it is ignored.</p>
<p>If the data set has no unique ID for each row, the cross-validation function copies the data set to a temporary table with a randomly assigned ID column. Setting this argument to the list of independent and dependent variables that are to be used in the calculation minimizes the copying workload by only copying the required data.</p>
<p class="enddd">The new temporary table is dropped after the computation has finished. </p>
</dd>
<dt><em>fold_num</em> </dt>
<dd>INTEGER, default: 10. The value of k, that is, the number of folds. Each fold uses a 1/fold_num fraction of the data for validation. </dd>
</dl>
<p>The parameter arrays for the modelling, prediction and metric functions can include the following special keywords:</p>
<ul>
<li><em>%data%</em> &ndash; The argument position for training/validation data</li>
<li><em>%model%</em> &ndash; The argument position for the output/input of modelling/prediction function</li>
<li><em>%id%</em> &ndash; The argument position of the unique ID column (user-provided or generated by the cross-validation function, as described above)</li>
<li><em>%prediction%</em> &ndash; The argument position for the output/input of prediction/metric function</li>
<li><em>%error%</em> &ndash; The argument position for the output of the metric function</li>
</ul>
<p><b>Note</b>: If the argument <em>explore_values</em> is NULL or has zero length, then the cross-validation function will only run a data folding.</p>
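<p>Each parameter array is expected to line up, element by element, with its type array, and <em>param_explored</em> must appear in <em>modelling_params</em>. A quick check before calling the function can catch mismatches. The following is a minimal sketch in plain SQL, not part of MADlib; the array values are illustrative, and the aliases <em>spec</em>, <em>params</em> and <em>types</em> are hypothetical names.</p>
<pre class="example">
-- Compare the number of modelling parameters with the number of declared
-- types, and confirm that the explored parameter name is in the array.
-- array_upper() equals the array length for default 1-based arrays.
SELECT array_upper(params, 1) AS n_params,
       array_upper(types, 1)  AS n_types,              -- should equal n_params
       'max_iter' = ANY(params) AS explored_param_present  -- should be true
FROM (
    SELECT '{%data%, %model%, high_priced, "ARRAY[1, bedroom, bath, size]", NULL, max_iter}'::varchar[] AS params,
           '{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[] AS types
) AS spec;
</pre>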
<p><a class="anchor" id="examples"></a></p><dl class="section user"><dt>Examples</dt><dd></dd></dl>
<ol type="1">
<li>Load some sample data: <pre class="example">
DROP TABLE IF EXISTS houses;
CREATE TABLE houses ( id INT,
tax INT,
bedroom INT,
bath FLOAT,
size INT,
lot INT,
zipcode INT,
price INT,
high_priced BOOLEAN
);
INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode, high_priced) VALUES
(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301, 'f'::boolean),
(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301, 'f'::boolean),
(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301, 'f'::boolean),
(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301, 'f'::boolean),
(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301, 't'::boolean),
(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301, 'f'::boolean),
(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301, 't'::boolean),
(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301, 't'::boolean),
(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301, 't'::boolean),
(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301, 't'::boolean),
(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301, 'f'::boolean),
(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301, 't'::boolean),
(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301, 't'::boolean),
(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301, 't'::boolean),
(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301, 'f'::boolean),
(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010, 'f'::boolean),
(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010, 't'::boolean),
(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010, 'f'::boolean),
(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean),
(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),
(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),
(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),
(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),
(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),
(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),
(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),
(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean),
(28 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean),
(29 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),
(30 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),
(31 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),
(32 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),
(33 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),
(34 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),
(35 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),
(36 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean);
</pre></li>
<li>Use the general function to explore lambda values for elastic net. (Note that elastic net also has a built-in cross-validation function for selecting the elastic net control parameter alpha and the regularization value lambda.) <pre class="example">
DROP TABLE IF EXISTS houses_cv_results;
SELECT madlib.cross_validation_general(
-- modelling_func
'madlib.elastic_net_train',
-- modelling_params
'{%data%, %model%, price, "array[tax, bath, size]", gaussian, 0.5, lambda, TRUE, NULL, fista,
"{eta = 2, max_stepsize = 2, use_active_set = t}",
NULL, 200, 1e-6}'::varchar[],
-- modelling_params_type
'{varchar, varchar, varchar, varchar, varchar, double precision,
double precision, boolean, varchar, varchar, varchar, varchar,
integer, double precision}'::varchar[],
-- param_explored
'lambda',
-- explore_values
'{0.1, 0.2}'::varchar[],
-- predict_func
'madlib.elastic_net_predict',
-- predict_params
'{%model%, %data%, %id%, %prediction%}'::varchar[],
-- predict_params_type
'{text, text, text, text}'::varchar[],
-- metric_func
'madlib.mse_error',
-- metric_params
'{%prediction%, %data%, %id%, price, %error%}'::varchar[],
-- metric_params_type
'{varchar, varchar, varchar, varchar, varchar}'::varchar[],
-- data_tbl
'houses',
-- data_id
'id',
-- id_is_random
FALSE,
-- validation_result
'houses_cv_results',
-- data_cols
NULL,
-- fold_num
3
);
SELECT * FROM houses_cv_results;
</pre> Results from the lambda values explored: <pre class="result">
lambda | mean_squared_error_avg | mean_squared_error_stddev
--------+------------------------+---------------------------
0.1 | 1094965503.24269 | 411974996.039577
0.2 | 1093350170.40664 | 411072137.632718
(2 rows)
</pre></li>
<li>Here we use the general function to explore maximum number of iterations for logistic regression: <pre class="example">
DROP TABLE IF EXISTS houses_logregr_cv;
SELECT madlib.cross_validation_general(
-- modelling_func
'madlib.logregr_train',
-- modelling_params
'{%data%, %model%, high_priced, "ARRAY[1, bedroom, bath, size]", NULL, max_iter}'::varchar[],
-- modelling_params_type
'{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[],
-- param_explored
'max_iter',
-- explore_values
'{2, 10, 40, 100}'::varchar[],
-- predict_func
'madlib.cv_logregr_predict',
-- predict_params
'{%model%, %data%, "ARRAY[1, bedroom, bath, size]", id, %prediction%}'::varchar[],
-- predict_params_type
'{varchar, varchar,varchar,varchar,varchar}'::varchar[],
-- metric_func
'madlib.misclassification_avg',
-- metric_params
'{%prediction%, %data%, id, high_priced, %error%}'::varchar[],
-- metric_params_type
'{varchar, varchar, varchar, varchar, varchar}'::varchar[],
-- data_tbl
'houses',
-- data_id
'id',
-- id_is_random
FALSE,
-- validation_result
'houses_logregr_cv',
-- data_cols
NULL,
-- fold_num
5
);
SELECT * FROM houses_logregr_cv;
</pre> Results from the explored number of iterations: <pre class="result">
max_iter | error_rate_avg | error_rate_stddev
----------+------------------------+--------------------------------------------
2 | 0.19285714285714285714 | 0.1589185390091927774733662830554976076700
10 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213
40 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213
100 | 0.22142857142857142857 | 0.1247446371183784331896638996881001527213
(4 rows)
</pre></li>
</ol>
<p><a class="anchor" id="notes"></a></p><dl class="section user"><dt>Notes</dt><dd></dd></dl>
<p>The lock management parameter <em>max_locks_per_transaction</em>, which usually is set to the default value of 64, limits the number of tables that can be dropped inside a single transaction (the cross-validation function). Thus, the number of different values of <em>param_explored</em> (that is, the length of the <em>explore_values</em> array) cannot be too large. For 10-fold cross validation, the limit of <code>length(<em>explore_values</em>)</code> is around 40. If the limit is exceeded, you may get an "out of shared memory" error because <em>max_locks_per_transaction</em> is exceeded.</p>
<p>One way to overcome this limitation is to run the cross-validation function multiple times, with each run covering a different region of values of the parameter.</p>
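<p>Splitting a long list of candidate values into batches might look like the following sketch in plain SQL. The 60 candidate values (0.01 through 0.60) and the batch size of 20 are illustrative assumptions, not recommendations. Each resulting array could then be passed as <em>explore_values</em> in a separate call, with a different <em>validation_result</em> table for each call.</p>
<pre class="example">
-- Check the current lock limit for this server.
SHOW max_locks_per_transaction;

-- Group 60 candidate values into 3 batches of 20 values each.
-- Integer division maps i = 1..20 to batch 0, 21..40 to batch 1,
-- and 41..60 to batch 2.
SELECT (i - 1) / 20 AS batch,
       array_agg((i * 0.01)::varchar ORDER BY i) AS explore_values
FROM generate_series(1, 60) AS i
GROUP BY (i - 1) / 20
ORDER BY batch;
</pre>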
<p>Note that MADlib implements cross-validation functions within certain individual modules, where it is possible to optimize the calculation to avoid dropping tables and prevent exceeding the <em>max_locks_per_transaction</em> limitation. Since module-specific cross-validation functions depend upon the implementation details of the modules to perform the optimization, they will not be as flexible as the generalized cross-validation function provided here.</p>
<p><a class="anchor" id="background"></a></p><dl class="section user"><dt>Technical Background</dt><dd></dd></dl>
<p>One round of cross validation involves partitioning a sample of data into complementary subsets, performing the analysis on one subset (called the training set), and validating the analysis on the other subset (called the validation set or test set). To reduce variability, multiple rounds of cross validation are performed using different partitions, and the validation results are averaged over the rounds.</p>
<p>In k-fold cross validation, the original sample is randomly partitioned into k equal-sized subsamples. Of the k subsamples, a single subsample is retained as the validation data for testing the model, and the remaining k&minus;1 subsamples are used as training data. The cross-validation process is repeated k times (the folds), with each of the k subsamples used exactly once as the validation data. The k results from the folds can be averaged (or otherwise combined) to produce a single estimate. The advantage of this method over repeated random sub-sampling is that all observations are used for both training and validation, and each observation is used for validation exactly once. 10-fold cross validation is commonly used, but in general k remains an unfixed parameter.</p>
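<p>The k-fold procedure described above can be sketched in plain SQL on the <em>houses</em> table from the examples. This is only an illustration under stated assumptions, not how the cross-validation function is implemented: the table name <em>houses_folds</em> is hypothetical, k is fixed at 3, and the "model" is a trivial stand-in that predicts the mean training price.</p>
<pre class="example">
-- Randomly assign each row to one of k = 3 folds of equal size.
DROP TABLE IF EXISTS houses_folds;
CREATE TABLE houses_folds AS
SELECT id, price, ntile(3) OVER (ORDER BY random()) AS fold
FROM houses;

-- For each fold, fit the stand-in model (mean price) on the other folds,
-- compute the mean squared error on the held-out fold, then combine the
-- k per-fold errors into an average and a standard deviation.
SELECT avg(fold_mse) AS mse_avg, stddev(fold_mse) AS mse_stddev
FROM (
    SELECT v.fold,
           avg((v.price - t.mean_price)^2) AS fold_mse
    FROM houses_folds v
    JOIN (
        -- Mean price of all rows NOT in the given fold (the training set).
        SELECT f.fold, avg(h.price) AS mean_price
        FROM (SELECT DISTINCT fold FROM houses_folds) f
        JOIN houses_folds h ON h.fold != f.fold
        GROUP BY f.fold
    ) t ON t.fold = v.fold
    GROUP BY v.fold
) per_fold;
</pre>
<p>Because the fold assignment uses <code>random()</code>, the resulting numbers differ from run to run. Each row lands in exactly one fold, so it is used for validation exactly once.</p>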
<p><a class="anchor" id="related"></a></p><dl class="section user"><dt>Related Topics</dt><dd></dd></dl>
<p>File <a class="el" href="cross__validation_8sql__in.html" title="SQL functions for cross validation. ">cross_validation.sql_in</a> documenting the SQL functions. </p>
</div><!-- contents -->
</div><!-- doc-content -->
<!-- start footer part -->
<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
<ul>
<li class="footer">Generated on Tue Jul 19 2022 12:19:26 for MADlib by
<a href="http://www.doxygen.org/index.html">
<img class="footer" src="doxygen.png" alt="doxygen"/></a> 1.8.13 </li>
</ul>
</div>
</body>
</html>