{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# XGBoost\n",
"XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable. It implements machine learning algorithms under the Gradient Boosting framework. XGBoost provides parallel tree boosting (also known as GBDT or GBM) that solves many data science problems quickly and accurately. XGBoost support was first added in MADlib 1.20.0."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: okislal@madlib'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 6.X\n",
"%sql postgresql://okislal@localhost:6600/madlib"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.20.0, git revision: rc/1.20.0-rc2-6-gb07f7466, cmake configuration time: Fri Jul 29 14:31:52 UTC 2022, build type: RelWithDebInfo, build system: Darwin-20.6.0, C compiler: Clang, C++ compiler: Clang</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.20.0, git revision: rc/1.20.0-rc2-6-gb07f7466, cmake configuration time: Fri Jul 29 14:31:52 UTC 2022, build type: RelWithDebInfo, build system: Darwin-20.6.0, C compiler: Clang, C++ compiler: Clang',)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Load data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sample data for XGBoost can be downloaded from the examples section of the MADlib documentation. Direct link: https://madlib.apache.org/docs/latest/example/madlib_xgboost_example.sql"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole</th>\n",
" <th>shucked</th>\n",
" <th>viscera</th>\n",
" <th>shell</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2026</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.47</td>\n",
" <td>0.15</td>\n",
" <td>0.9205</td>\n",
" <td>0.381</td>\n",
" <td>0.2435</td>\n",
" <td>0.2675</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1796</td>\n",
" <td>F</td>\n",
" <td>0.58</td>\n",
" <td>0.43</td>\n",
" <td>0.17</td>\n",
" <td>1.48</td>\n",
" <td>0.6535</td>\n",
" <td>0.324</td>\n",
" <td>0.4155</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>829</td>\n",
" <td>I</td>\n",
" <td>0.41</td>\n",
" <td>0.325</td>\n",
" <td>0.1</td>\n",
" <td>0.394</td>\n",
" <td>0.208</td>\n",
" <td>0.0655</td>\n",
" <td>0.106</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3703</td>\n",
" <td>F</td>\n",
" <td>0.665</td>\n",
" <td>0.54</td>\n",
" <td>0.195</td>\n",
" <td>1.764</td>\n",
" <td>0.8505</td>\n",
" <td>0.3615</td>\n",
" <td>0.47</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1665</td>\n",
" <td>I</td>\n",
" <td>0.605</td>\n",
" <td>0.47</td>\n",
" <td>0.145</td>\n",
" <td>0.8025</td>\n",
" <td>0.379</td>\n",
" <td>0.2265</td>\n",
" <td>0.22</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3901</td>\n",
" <td>M</td>\n",
" <td>0.445</td>\n",
" <td>0.345</td>\n",
" <td>0.14</td>\n",
" <td>0.476</td>\n",
" <td>0.2055</td>\n",
" <td>0.1015</td>\n",
" <td>0.1085</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2734</td>\n",
" <td>I</td>\n",
" <td>0.415</td>\n",
" <td>0.335</td>\n",
" <td>0.1</td>\n",
" <td>0.358</td>\n",
" <td>0.169</td>\n",
" <td>0.067</td>\n",
" <td>0.105</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1155</td>\n",
" <td>M</td>\n",
" <td>0.6</td>\n",
" <td>0.455</td>\n",
" <td>0.17</td>\n",
" <td>1.1915</td>\n",
" <td>0.696</td>\n",
" <td>0.2395</td>\n",
" <td>0.24</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3467</td>\n",
" <td>M</td>\n",
" <td>0.64</td>\n",
" <td>0.5</td>\n",
" <td>0.17</td>\n",
" <td>1.4545</td>\n",
" <td>0.642</td>\n",
" <td>0.3575</td>\n",
" <td>0.354</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2433</td>\n",
" <td>F</td>\n",
" <td>0.61</td>\n",
" <td>0.485</td>\n",
" <td>0.165</td>\n",
" <td>1.087</td>\n",
" <td>0.4255</td>\n",
" <td>0.232</td>\n",
" <td>0.38</td>\n",
" <td>11</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2026, u'F', 0.55, 0.47, 0.15, 0.9205, 0.381, 0.2435, 0.2675, 10),\n",
" (1796, u'F', 0.58, 0.43, 0.17, 1.48, 0.6535, 0.324, 0.4155, 10),\n",
" (829, u'I', 0.41, 0.325, 0.1, 0.394, 0.208, 0.0655, 0.106, 6),\n",
" (3703, u'F', 0.665, 0.54, 0.195, 1.764, 0.8505, 0.3615, 0.47, 11),\n",
" (1665, u'I', 0.605, 0.47, 0.145, 0.8025, 0.379, 0.2265, 0.22, 9),\n",
" (3901, u'M', 0.445, 0.345, 0.14, 0.476, 0.2055, 0.1015, 0.1085, 15),\n",
" (2734, u'I', 0.415, 0.335, 0.1, 0.358, 0.169, 0.067, 0.105, 7),\n",
" (1155, u'M', 0.6, 0.455, 0.17, 1.1915, 0.696, 0.2395, 0.24, 8),\n",
" (3467, u'M', 0.64, 0.5, 0.17, 1.4545, 0.642, 0.3575, 0.354, 9),\n",
" (2433, u'F', 0.61, 0.485, 0.165, 1.087, 0.4255, 0.232, 0.38, 11)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"SELECT * FROM abalone LIMIT 10;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Train a single XGBoost model\n",
"Note that the function collects the training data onto a single segment and runs the xgboost Python process on that segment's host."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>features</th>\n",
" <th>importance</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>fscore</th>\n",
" <th>support</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n",
" <td>[u'1205', u'1179', u'1115', u'941', u'926', u'711', u'580', u'454']</td>\n",
" <td>[u'0.45390070921985815', u'0.6984615384615385', u'0.4780701754385965']</td>\n",
" <td>[u'0.4866920152091255', u'0.8315018315018315', u'0.36454849498327757']</td>\n",
" <td>[u'0.46972477064220186', u'0.7591973244147157', u'0.413662239089184']</td>\n",
" <td>[u'263.0', u'273.0', u'299.0']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], [u'1205', u'1179', u'1115', u'941', u'926', u'711', u'580', u'454'], [u'0.45390070921985815', u'0.6984615384615385', u'0.4780701754385965'], [u'0.4866920152091255', u'0.8315018315018315', u'0.36454849498327757'], [u'0.46972477064220186', u'0.7591973244147157', u'0.413662239089184'], [u'263.0', u'273.0', u'299.0'])]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS xgb_single_out, xgb_single_out_summary;\n",
"SELECT madlib.xgboost(\n",
" 'abalone', -- Training table\n",
" 'xgb_single_out', -- Grid search results table.\n",
" 'id', -- Id column\n",
" 'sex', -- Class label column\n",
" '*', -- Independent variables \n",
" NULL, -- Columns to exclude from features \n",
" $$ \n",
" {\n",
" 'learning_rate': [0.01], #Step size shrinkage (eta). For smaller values, increase n_estimators\n",
" 'max_depth': [9], #Larger values could lead to overfitting\n",
" 'subsample': [0.85], #Introduce randomness in the samples picked to prevent overfitting\n",
" 'colsample_bytree': [0.85], #Introduce randomness in the features picked to prevent overfitting\n",
" 'min_child_weight': [10], #Larger values help prevent overfitting\n",
" 'n_estimators':[100] #Number of boosting rounds (trees); too many can lead to overfitting\n",
" } \n",
" $$, -- XGBoost grid search parameters\n",
" '', -- Class weights\n",
" 0.8, -- Training set size ratio\n",
" NULL -- Variable used to do the test/train split.\n",
");\n",
"\n",
"SELECT features, importance, precision, recall, fscore, support FROM xgb_single_out_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Run XGBoost Prediction"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex_predicted</th>\n",
" <th>sex_proba_predicted</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>I</td>\n",
" <td>[0.180475369096, 0.575919687748, 0.243604928255]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>I</td>\n",
" <td>[0.27669274807, 0.44246467948, 0.280842572451]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>[0.319970279932, 0.313613921404, 0.366415828466]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>[0.384111016989, 0.266917943954, 0.348971098661]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>[0.344503968954, 0.315709024668, 0.339786976576]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>F</td>\n",
" <td>[0.401963979006, 0.242080762982, 0.355955272913]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>I</td>\n",
" <td>[0.315914690495, 0.363648235798, 0.32043710351]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>I</td>\n",
" <td>[0.184259131551, 0.606196165085, 0.209544733167]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>M</td>\n",
" <td>[0.27689999342, 0.361068278551, 0.362031728029]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>F</td>\n",
" <td>[0.367550551891, 0.345346838236, 0.287102639675]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2, u'I', [0.180475369096, 0.575919687748, 0.243604928255]),\n",
" (3, u'I', [0.27669274807, 0.44246467948, 0.280842572451]),\n",
" (4, u'M', [0.319970279932, 0.313613921404, 0.366415828466]),\n",
" (7, u'F', [0.384111016989, 0.266917943954, 0.348971098661]),\n",
" (8, u'F', [0.344503968954, 0.315709024668, 0.339786976576]),\n",
" (16, u'F', [0.401963979006, 0.242080762982, 0.355955272913]),\n",
" (18, u'I', [0.315914690495, 0.363648235798, 0.32043710351]),\n",
" (19, u'I', [0.184259131551, 0.606196165085, 0.209544733167]),\n",
" (22, u'M', [0.27689999342, 0.361068278551, 0.362031728029]),\n",
" (24, u'F', [0.367550551891, 0.345346838236, 0.287102639675])]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS xgb_single_score_out, xgb_single_score_out_metrics, xgb_single_score_out_roc_curve;\n",
"\n",
"SELECT madlib.xgboost_predict(\n",
" 'abalone', -- test_table\n",
" 'xgb_single_out', -- model_table\n",
" 'xgb_single_score_out', -- predict_output_table\n",
" 'id', -- id_column\n",
" 'sex', -- class_label\n",
" 1 -- model_filters\n",
");\n",
"\n",
"SELECT * FROM xgb_single_score_out LIMIT 10;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Run XGBoost with grid search\n",
"The parameter options are combined to form a grid, and the resulting configurations are explored in parallel by running distinct xgboost processes on different segments. The following example generates 4 configurations by combining 'learning_rate': [0.01,0.1] with 'max_depth': [9,12]."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"Done.\n",
"1 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>features</th>\n",
" <th>params</th>\n",
" <th>importance</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>fscore</th>\n",
" <th>support</th>\n",
" <th>params_index</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n",
" <td>('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')</td>\n",
" <td>[u'1294', u'1183', u'1069', u'974', u'900', u'717', u'608', u'490']</td>\n",
" <td>[u'0.48148148148148145', u'0.6883561643835616', u'0.47619047619047616']</td>\n",
" <td>[u'0.4642857142857143', u'0.788235294117647', u'0.43333333333333335']</td>\n",
" <td>[u'0.4727272727272727', u'0.7349177330895795', u'0.4537521815008726']</td>\n",
" <td>[u'280.0', u'255.0', u'300.0']</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n",
" <td>('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')</td>\n",
" <td>[u'953', u'882', u'872', u'848', u'579', u'500', u'454', u'429']</td>\n",
" <td>[u'0.4259927797833935', u'0.7080536912751678', u'0.47307692307692306']</td>\n",
" <td>[u'0.44696969696969696', u'0.7962264150943397', u'0.4019607843137255']</td>\n",
" <td>[u'0.4362292051756007', u'0.74955595026643', u'0.43462897526501765']</td>\n",
" <td>[u'264.0', u'265.0', u'306.0']</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n",
" <td>('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')</td>\n",
" <td>[u'1168', u'1099', u'1069', u'908', u'717', u'534', u'471', u'462']</td>\n",
" <td>[u'0.4007220216606498', u'0.775', u'0.49640287769784175']</td>\n",
" <td>[u'0.4605809128630705', u'0.775', u'0.4394904458598726']</td>\n",
" <td>[u'0.42857142857142855', u'0.775', u'0.46621621621621623']</td>\n",
" <td>[u'241.0', u'280.0', u'314.0']</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n",
" <td>('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')</td>\n",
" <td>[u'1257', u'1211', u'1105', u'904', u'867', u'824', u'649', u'400']</td>\n",
" <td>[u'0.40148698884758366', u'0.6488095238095238', u'0.49130434782608695']</td>\n",
" <td>[u'0.45188284518828453', u'0.8352490421455939', u'0.3373134328358209']</td>\n",
" <td>[u'0.4251968503937008', u'0.7303182579564489', u'0.4']</td>\n",
" <td>[u'239.0', u'261.0', u'335.0']</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')\", [u'1294', u'1183', u'1069', u'974', u'900', u'717', u'608', u'490'], [u'0.48148148148148145', u'0.6883561643835616', u'0.47619047619047616'], [u'0.4642857142857143', u'0.788235294117647', u'0.43333333333333335'], [u'0.4727272727272727', u'0.7349177330895795', u'0.4537521815008726'], [u'280.0', u'255.0', u'300.0'], 2),\n",
" ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')\", [u'953', u'882', u'872', u'848', u'579', u'500', u'454', u'429'], [u'0.4259927797833935', u'0.7080536912751678', u'0.47307692307692306'], [u'0.44696969696969696', u'0.7962264150943397', u'0.4019607843137255'], [u'0.4362292051756007', u'0.74955595026643', u'0.43462897526501765'], [u'264.0', u'265.0', u'306.0'], 3),\n",
" ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')\", [u'1168', u'1099', u'1069', u'908', u'717', u'534', u'471', u'462'], [u'0.4007220216606498', u'0.775', u'0.49640287769784175'], [u'0.4605809128630705', u'0.775', u'0.4394904458598726'], [u'0.42857142857142855', u'0.775', u'0.46621621621621623'], [u'241.0', u'280.0', u'314.0'], 4),\n",
" ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')\", [u'1257', u'1211', u'1105', u'904', u'867', u'824', u'649', u'400'], [u'0.40148698884758366', u'0.6488095238095238', u'0.49130434782608695'], [u'0.45188284518828453', u'0.8352490421455939', u'0.3373134328358209'], [u'0.4251968503937008', u'0.7303182579564489', u'0.4'], [u'239.0', u'261.0', u'335.0'], 1)]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS xgb_grid_out, xgb_grid_out_summary;\n",
"\n",
"SELECT madlib.xgboost(\n",
" 'abalone', -- Training table\n",
" 'xgb_grid_out', -- Grid search results table.\n",
" 'id', -- Id column\n",
" 'sex', -- Class label column\n",
" '*', -- Independent variables\n",
" NULL, -- Columns to exclude from features\n",
" $$\n",
" {\n",
" 'learning_rate': [0.01,0.1], #Step size shrinkage (eta). For smaller values, increase n_estimators\n",
" 'max_depth': [9,12], #Larger values could lead to overfitting\n",
" 'subsample': [0.85], #Introduce randomness in the samples picked to prevent overfitting\n",
" 'colsample_bytree': [0.85], #Introduce randomness in the features picked to prevent overfitting\n",
" 'min_child_weight': [10], #Larger values help prevent overfitting\n",
" 'n_estimators':[100] #Number of boosting rounds (trees); too many can lead to overfitting\n",
" }\n",
" $$, -- XGBoost grid search parameters\n",
" '', -- Class weights\n",
" 0.8, -- Training set size ratio\n",
" NULL -- Variable used to do the test/train split.\n",
");\n",
"\n",
"SELECT features, params, importance, precision, recall, fscore, support, params_index FROM xgb_grid_out_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Run XGBoost Prediction on Grid Output Table\n",
"Suppose we are interested in model 2 (the row with `params_index` = 2 in the summary table) and want to run predictions with it."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://okislal@localhost:6600/madlib\n",
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex_predicted</th>\n",
" <th>sex_proba_predicted</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>I</td>\n",
" <td>[0.312986373901, 0.34792137146, 0.339092254639]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>I</td>\n",
" <td>[0.337030380964, 0.379457473755, 0.283512145281]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>I</td>\n",
" <td>[0.292645961046, 0.382402688265, 0.324951350689]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>I</td>\n",
" <td>[0.235972866416, 0.479740768671, 0.284286379814]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>M</td>\n",
" <td>[0.3711399436, 0.21823567152, 0.410624355078]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>M</td>\n",
" <td>[0.343350559473, 0.223895892501, 0.432753533125]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>M</td>\n",
" <td>[0.359976351261, 0.246053755283, 0.393969863653]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>F</td>\n",
" <td>[0.437169611454, 0.199478805065, 0.363351553679]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>F</td>\n",
" <td>[0.516167163849, 0.170660674572, 0.313172131777]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>I</td>\n",
" <td>[0.252817928791, 0.48461201787, 0.262570023537]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'I', [0.312986373901, 0.34792137146, 0.339092254639]),\n",
" (12, u'I', [0.337030380964, 0.379457473755, 0.283512145281]),\n",
" (15, u'I', [0.292645961046, 0.382402688265, 0.324951350689]),\n",
" (20, u'I', [0.235972866416, 0.479740768671, 0.284286379814]),\n",
" (23, u'M', [0.3711399436, 0.21823567152, 0.410624355078]),\n",
" (26, u'M', [0.343350559473, 0.223895892501, 0.432753533125]),\n",
" (30, u'M', [0.359976351261, 0.246053755283, 0.393969863653]),\n",
" (31, u'F', [0.437169611454, 0.199478805065, 0.363351553679]),\n",
" (35, u'F', [0.516167163849, 0.170660674572, 0.313172131777]),\n",
" (36, u'I', [0.252817928791, 0.48461201787, 0.262570023537])]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"\n",
"DROP TABLE IF EXISTS xgb_grid_score_out, xgb_grid_score_out_metrics, xgb_grid_score_out_roc_curve;\n",
"\n",
"SELECT madlib.xgboost_predict(\n",
" 'abalone', -- test_table\n",
" 'xgb_grid_out', -- model_table\n",
" 'xgb_grid_score_out', -- predict_output_table\n",
" 'id', -- id_column\n",
" 'sex', -- class_label\n",
" 2 -- model_filters\n",
");\n",
"SELECT * FROM xgb_grid_score_out LIMIT 10;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.17"
}
},
"nbformat": 4,
"nbformat_minor": 1
}