| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# XGBoost\n", |
| "XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable. It implements machine learning algorithms under the Gradient Boosting framework. XGBoost provides parallel tree boosting (also known as GBDT or GBM) that solves many data science problems quickly and accurately. XGBoost support was first added in MADlib 1.20.0." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Load the `sql` IPython extension, which provides the %sql / %%sql magics used in the cells below\n", |
| "%load_ext sql" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "u'Connected: okislal@madlib'" |
| ] |
| }, |
| "execution_count": 4, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Connect to the Greenplum Database 6.x instance that has MADlib installed (edit user/host/port/database for your environment)\n", |
| "%sql postgresql://okislal@localhost:6600/madlib" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>version</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>MADlib version: 1.20.0, git revision: rc/1.20.0-rc2-6-gb07f7466, cmake configuration time: Fri Jul 29 14:31:52 UTC 2022, build type: RelWithDebInfo, build system: Darwin-20.6.0, C compiler: Clang, C++ compiler: Clang</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'MADlib version: 1.20.0, git revision: rc/1.20.0-rc2-6-gb07f7466, cmake configuration time: Fri Jul 29 14:31:52 UTC 2022, build type: RelWithDebInfo, build system: Darwin-20.6.0, C compiler: Clang, C++ compiler: Clang',)]" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Confirm which MADlib release is installed in the connected database\n", |
| "%sql select madlib.version();" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 1. Load data" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The sample data for XGBoost can be downloaded from the examples section of the MADlib documentation. Direct link: https://madlib.apache.org/docs/latest/example/madlib_xgboost_example.sql" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>sex</th>\n", |
| " <th>length</th>\n", |
| " <th>diameter</th>\n", |
| " <th>height</th>\n", |
| " <th>whole</th>\n", |
| " <th>shucked</th>\n", |
| " <th>viscera</th>\n", |
| " <th>shell</th>\n", |
| " <th>rings</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2026</td>\n", |
| " <td>F</td>\n", |
| " <td>0.55</td>\n", |
| " <td>0.47</td>\n", |
| " <td>0.15</td>\n", |
| " <td>0.9205</td>\n", |
| " <td>0.381</td>\n", |
| " <td>0.2435</td>\n", |
| " <td>0.2675</td>\n", |
| " <td>10</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1796</td>\n", |
| " <td>F</td>\n", |
| " <td>0.58</td>\n", |
| " <td>0.43</td>\n", |
| " <td>0.17</td>\n", |
| " <td>1.48</td>\n", |
| " <td>0.6535</td>\n", |
| " <td>0.324</td>\n", |
| " <td>0.4155</td>\n", |
| " <td>10</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>829</td>\n", |
| " <td>I</td>\n", |
| " <td>0.41</td>\n", |
| " <td>0.325</td>\n", |
| " <td>0.1</td>\n", |
| " <td>0.394</td>\n", |
| " <td>0.208</td>\n", |
| " <td>0.0655</td>\n", |
| " <td>0.106</td>\n", |
| " <td>6</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3703</td>\n", |
| " <td>F</td>\n", |
| " <td>0.665</td>\n", |
| " <td>0.54</td>\n", |
| " <td>0.195</td>\n", |
| " <td>1.764</td>\n", |
| " <td>0.8505</td>\n", |
| " <td>0.3615</td>\n", |
| " <td>0.47</td>\n", |
| " <td>11</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1665</td>\n", |
| " <td>I</td>\n", |
| " <td>0.605</td>\n", |
| " <td>0.47</td>\n", |
| " <td>0.145</td>\n", |
| " <td>0.8025</td>\n", |
| " <td>0.379</td>\n", |
| " <td>0.2265</td>\n", |
| " <td>0.22</td>\n", |
| " <td>9</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3901</td>\n", |
| " <td>M</td>\n", |
| " <td>0.445</td>\n", |
| " <td>0.345</td>\n", |
| " <td>0.14</td>\n", |
| " <td>0.476</td>\n", |
| " <td>0.2055</td>\n", |
| " <td>0.1015</td>\n", |
| " <td>0.1085</td>\n", |
| " <td>15</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2734</td>\n", |
| " <td>I</td>\n", |
| " <td>0.415</td>\n", |
| " <td>0.335</td>\n", |
| " <td>0.1</td>\n", |
| " <td>0.358</td>\n", |
| " <td>0.169</td>\n", |
| " <td>0.067</td>\n", |
| " <td>0.105</td>\n", |
| " <td>7</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1155</td>\n", |
| " <td>M</td>\n", |
| " <td>0.6</td>\n", |
| " <td>0.455</td>\n", |
| " <td>0.17</td>\n", |
| " <td>1.1915</td>\n", |
| " <td>0.696</td>\n", |
| " <td>0.2395</td>\n", |
| " <td>0.24</td>\n", |
| " <td>8</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3467</td>\n", |
| " <td>M</td>\n", |
| " <td>0.64</td>\n", |
| " <td>0.5</td>\n", |
| " <td>0.17</td>\n", |
| " <td>1.4545</td>\n", |
| " <td>0.642</td>\n", |
| " <td>0.3575</td>\n", |
| " <td>0.354</td>\n", |
| " <td>9</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2433</td>\n", |
| " <td>F</td>\n", |
| " <td>0.61</td>\n", |
| " <td>0.485</td>\n", |
| " <td>0.165</td>\n", |
| " <td>1.087</td>\n", |
| " <td>0.4255</td>\n", |
| " <td>0.232</td>\n", |
| " <td>0.38</td>\n", |
| " <td>11</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(2026, u'F', 0.55, 0.47, 0.15, 0.9205, 0.381, 0.2435, 0.2675, 10),\n", |
| " (1796, u'F', 0.58, 0.43, 0.17, 1.48, 0.6535, 0.324, 0.4155, 10),\n", |
| " (829, u'I', 0.41, 0.325, 0.1, 0.394, 0.208, 0.0655, 0.106, 6),\n", |
| " (3703, u'F', 0.665, 0.54, 0.195, 1.764, 0.8505, 0.3615, 0.47, 11),\n", |
| " (1665, u'I', 0.605, 0.47, 0.145, 0.8025, 0.379, 0.2265, 0.22, 9),\n", |
| " (3901, u'M', 0.445, 0.345, 0.14, 0.476, 0.2055, 0.1015, 0.1085, 15),\n", |
| " (2734, u'I', 0.415, 0.335, 0.1, 0.358, 0.169, 0.067, 0.105, 7),\n", |
| " (1155, u'M', 0.6, 0.455, 0.17, 1.1915, 0.696, 0.2395, 0.24, 8),\n", |
| " (3467, u'M', 0.64, 0.5, 0.17, 1.4545, 0.642, 0.3575, 0.354, 9),\n", |
| " (2433, u'F', 0.61, 0.485, 0.165, 1.087, 0.4255, 0.232, 0.38, 11)]" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM abalone\n", |
| "ORDER BY id -- without ORDER BY, LIMIT returns an arbitrary subset that can change between runs\n", |
| "LIMIT 10;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 2. Run a single XGBoost training\n", |
| "Note that the function collates the data onto a single segment and runs the XGBoost Python process on that segment's host." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "Done.\n", |
| "1 rows affected.\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>features</th>\n", |
| " <th>importance</th>\n", |
| " <th>precision</th>\n", |
| " <th>recall</th>\n", |
| " <th>fscore</th>\n", |
| " <th>support</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n", |
| " <td>[u'1205', u'1179', u'1115', u'941', u'926', u'711', u'580', u'454']</td>\n", |
| " <td>[u'0.45390070921985815', u'0.6984615384615385', u'0.4780701754385965']</td>\n", |
| " <td>[u'0.4866920152091255', u'0.8315018315018315', u'0.36454849498327757']</td>\n", |
| " <td>[u'0.46972477064220186', u'0.7591973244147157', u'0.413662239089184']</td>\n", |
| " <td>[u'263.0', u'273.0', u'299.0']</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], [u'1205', u'1179', u'1115', u'941', u'926', u'711', u'580', u'454'], [u'0.45390070921985815', u'0.6984615384615385', u'0.4780701754385965'], [u'0.4866920152091255', u'0.8315018315018315', u'0.36454849498327757'], [u'0.46972477064220186', u'0.7591973244147157', u'0.413662239089184'], [u'263.0', u'273.0', u'299.0'])]" |
| ] |
| }, |
| "execution_count": 17, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS xgb_single_out, xgb_single_out_summary;\n", |
| "\n", |
| "-- Train a single XGBoost model: every parameter below has exactly one value\n", |
| "SELECT madlib.xgboost(\n", |
| " 'abalone', -- Source table with the training data\n", |
| " 'xgb_single_out', -- Output model table\n", |
| " 'id', -- Row id column\n", |
| " 'sex', -- Class label (dependent variable)\n", |
| " '*', -- Independent variables: all remaining columns\n", |
| " NULL, -- No columns excluded from the features\n", |
| " $$ \n", |
| " {\n", |
| " 'learning_rate': [0.01], #Regularization on weights (eta). For smaller values, increase n_estimators\n", |
| " 'max_depth': [9],#Larger values could lead to overfitting\n", |
| " 'subsample': [0.85],#introduce randomness in samples picked to prevent overfitting\n", |
| " 'colsample_bytree': [0.85],#introduce randomness in features picked to prevent overfitting\n", |
| " 'min_child_weight': [10],#larger values will prevent over-fitting\n", |
| " 'n_estimators':[100] #More estimators, lesser variance (better fit on test set) \n", |
| " } \n", |
| " $$, -- XGBoost parameters (one value each, so one configuration)\n", |
| " '', -- Class weights (empty string: none specified)\n", |
| " 0.8, -- Fraction of rows used for training\n", |
| " NULL -- Variable used for the test/train split (NULL: none)\n", |
| ");\n", |
| "\n", |
| "-- Feature importance and per-class evaluation metrics of the trained model\n", |
| "SELECT features, importance, precision, recall, fscore, support\n", |
| "FROM xgb_single_out_summary;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
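| "# 3. Run XGBoost Prediction\n", |
| "The model is applied to the whole `abalone` table. Because 80% of these rows were used for training, the predictions below are not a held-out evaluation." |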
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 25, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "Done.\n", |
| "1 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>sex_predicted</th>\n", |
| " <th>sex_proba_predicted</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.180475369096, 0.575919687748, 0.243604928255]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.27669274807, 0.44246467948, 0.280842572451]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>M</td>\n", |
| " <td>[0.319970279932, 0.313613921404, 0.366415828466]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.384111016989, 0.266917943954, 0.348971098661]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.344503968954, 0.315709024668, 0.339786976576]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.401963979006, 0.242080762982, 0.355955272913]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.315914690495, 0.363648235798, 0.32043710351]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.184259131551, 0.606196165085, 0.209544733167]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>M</td>\n", |
| " <td>[0.27689999342, 0.361068278551, 0.362031728029]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.367550551891, 0.345346838236, 0.287102639675]</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(2, u'I', [0.180475369096, 0.575919687748, 0.243604928255]),\n", |
| " (3, u'I', [0.27669274807, 0.44246467948, 0.280842572451]),\n", |
| " (4, u'M', [0.319970279932, 0.313613921404, 0.366415828466]),\n", |
| " (7, u'F', [0.384111016989, 0.266917943954, 0.348971098661]),\n", |
| " (8, u'F', [0.344503968954, 0.315709024668, 0.339786976576]),\n", |
| " (16, u'F', [0.401963979006, 0.242080762982, 0.355955272913]),\n", |
| " (18, u'I', [0.315914690495, 0.363648235798, 0.32043710351]),\n", |
| " (19, u'I', [0.184259131551, 0.606196165085, 0.209544733167]),\n", |
| " (22, u'M', [0.27689999342, 0.361068278551, 0.362031728029]),\n", |
| " (24, u'F', [0.367550551891, 0.345346838236, 0.287102639675])]" |
| ] |
| }, |
| "execution_count": 25, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS xgb_single_score_out, xgb_single_score_out_metrics, xgb_single_score_out_roc_curve;\n", |
| "\n", |
| "-- NOTE: abalone also contains the rows used for training, so this is not a held-out evaluation\n", |
| "SELECT madlib.xgboost_predict(\n", |
| " 'abalone', -- test_table\n", |
| " 'xgb_single_out', -- model_table\n", |
| " 'xgb_single_score_out', -- predict_output_table\n", |
| " 'id', -- id_column\n", |
| " 'sex', -- class_label\n", |
| " 1 -- model_filters (params_index of the model to use)\n", |
| ");\n", |
| "\n", |
| "-- ORDER BY makes the previewed rows reproducible across runs\n", |
| "SELECT * FROM xgb_single_score_out ORDER BY id LIMIT 10;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 4. Run XGBoost with grid search\n", |
| "The parameter options are combined to form a grid, and the resulting configurations are explored in parallel by running a distinct XGBoost process for each of them on different segments. The following example generates 4 configurations to test by combining 'learning_rate': [0.01,0.1] and 'max_depth': [9,12]." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 32, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "Done.\n", |
| "1 rows affected.\n", |
| "4 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>features</th>\n", |
| " <th>params</th>\n", |
| " <th>importance</th>\n", |
| " <th>precision</th>\n", |
| " <th>recall</th>\n", |
| " <th>fscore</th>\n", |
| " <th>support</th>\n", |
| " <th>params_index</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n", |
| " <td>('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')</td>\n", |
| " <td>[u'1294', u'1183', u'1069', u'974', u'900', u'717', u'608', u'490']</td>\n", |
| " <td>[u'0.48148148148148145', u'0.6883561643835616', u'0.47619047619047616']</td>\n", |
| " <td>[u'0.4642857142857143', u'0.788235294117647', u'0.43333333333333335']</td>\n", |
| " <td>[u'0.4727272727272727', u'0.7349177330895795', u'0.4537521815008726']</td>\n", |
| " <td>[u'280.0', u'255.0', u'300.0']</td>\n", |
| " <td>2</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n", |
| " <td>('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')</td>\n", |
| " <td>[u'953', u'882', u'872', u'848', u'579', u'500', u'454', u'429']</td>\n", |
| " <td>[u'0.4259927797833935', u'0.7080536912751678', u'0.47307692307692306']</td>\n", |
| " <td>[u'0.44696969696969696', u'0.7962264150943397', u'0.4019607843137255']</td>\n", |
| " <td>[u'0.4362292051756007', u'0.74955595026643', u'0.43462897526501765']</td>\n", |
| " <td>[u'264.0', u'265.0', u'306.0']</td>\n", |
| " <td>3</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n", |
| " <td>('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')</td>\n", |
| " <td>[u'1168', u'1099', u'1069', u'908', u'717', u'534', u'471', u'462']</td>\n", |
| " <td>[u'0.4007220216606498', u'0.775', u'0.49640287769784175']</td>\n", |
| " <td>[u'0.4605809128630705', u'0.775', u'0.4394904458598726']</td>\n", |
| " <td>[u'0.42857142857142855', u'0.775', u'0.46621621621621623']</td>\n", |
| " <td>[u'241.0', u'280.0', u'314.0']</td>\n", |
| " <td>4</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings']</td>\n", |
| " <td>('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')</td>\n", |
| " <td>[u'1257', u'1211', u'1105', u'904', u'867', u'824', u'649', u'400']</td>\n", |
| " <td>[u'0.40148698884758366', u'0.6488095238095238', u'0.49130434782608695']</td>\n", |
| " <td>[u'0.45188284518828453', u'0.8352490421455939', u'0.3373134328358209']</td>\n", |
| " <td>[u'0.4251968503937008', u'0.7303182579564489', u'0.4']</td>\n", |
| " <td>[u'239.0', u'261.0', u'335.0']</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')\", [u'1294', u'1183', u'1069', u'974', u'900', u'717', u'608', u'490'], [u'0.48148148148148145', u'0.6883561643835616', u'0.47619047619047616'], [u'0.4642857142857143', u'0.788235294117647', u'0.43333333333333335'], [u'0.4727272727272727', u'0.7349177330895795', u'0.4537521815008726'], [u'280.0', u'255.0', u'300.0'], 2),\n", |
| " ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')\", [u'953', u'882', u'872', u'848', u'579', u'500', u'454', u'429'], [u'0.4259927797833935', u'0.7080536912751678', u'0.47307692307692306'], [u'0.44696969696969696', u'0.7962264150943397', u'0.4019607843137255'], [u'0.4362292051756007', u'0.74955595026643', u'0.43462897526501765'], [u'264.0', u'265.0', u'306.0'], 3),\n", |
| " ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.1', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=12')\", [u'1168', u'1099', u'1069', u'908', u'717', u'534', u'471', u'462'], [u'0.4007220216606498', u'0.775', u'0.49640287769784175'], [u'0.4605809128630705', u'0.775', u'0.4394904458598726'], [u'0.42857142857142855', u'0.775', u'0.46621621621621623'], [u'241.0', u'280.0', u'314.0'], 4),\n", |
| " ([u'length', u'diameter', u'height', u'whole_weight', u'shucked_weight', u'viscera_weight', u'shell_weight', u'rings'], u\"('colsample_bytree=0.85', 'learning_rate=0.01', 'min_child_weight=10', 'n_estimators=100', 'subsample=0.85', 'max_depth=9')\", [u'1257', u'1211', u'1105', u'904', u'867', u'824', u'649', u'400'], [u'0.40148698884758366', u'0.6488095238095238', u'0.49130434782608695'], [u'0.45188284518828453', u'0.8352490421455939', u'0.3373134328358209'], [u'0.4251968503937008', u'0.7303182579564489', u'0.4'], [u'239.0', u'261.0', u'335.0'], 1)]" |
| ] |
| }, |
| "execution_count": 32, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS xgb_grid_out, xgb_grid_out_summary;\n", |
| "\n", |
| "-- Schema-qualified (madlib.) like the other MADlib calls, so the cell does not depend on search_path\n", |
| "SELECT madlib.xgboost(\n", |
| " 'abalone', -- Training table\n", |
| " 'xgb_grid_out', -- Grid search results table.\n", |
| " 'id', -- Id column\n", |
| " 'sex', -- Class label column\n", |
| " '*', -- Independent variables\n", |
| " NULL, -- Columns to exclude from features\n", |
| " $$\n", |
| " {\n", |
| " 'learning_rate': [0.01,0.1], #Regularization on weights (eta). For smaller values, increase n_estimators\n", |
| " 'max_depth': [9,12],#Larger values could lead to overfitting\n", |
| " 'subsample': [0.85],#introduce randomness in samples picked to prevent overfitting\n", |
| " 'colsample_bytree': [0.85],#introduce randomness in features picked to prevent overfitting\n", |
| " 'min_child_weight': [10],#larger values will prevent over-fitting\n", |
| " 'n_estimators':[100] #More estimators, lesser variance (better fit on test set)\n", |
| " }\n", |
| " $$, -- XGBoost grid search parameters (2 x 2 = 4 configurations)\n", |
| " '', -- Class weights\n", |
| " 0.8, -- Training set size ratio\n", |
| " NULL -- Variable used to do the test/train split.\n", |
| ");\n", |
| "\n", |
| "-- One summary row per configuration; ORDER BY keeps the display order stable across runs\n", |
| "SELECT features, params, importance, precision, recall, fscore, support, params_index\n", |
| "FROM xgb_grid_out_summary\n", |
| "ORDER BY params_index;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 5. Run XGBoost Prediction on Grid Output Table\n", |
| "Suppose we are interested in the model with `params_index` 2 and want to run predictions with it." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 31, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| " * postgresql://okislal@localhost:6600/madlib\n", |
| "Done.\n", |
| "1 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>sex_predicted</th>\n", |
| " <th>sex_proba_predicted</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.312986373901, 0.34792137146, 0.339092254639]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.337030380964, 0.379457473755, 0.283512145281]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.292645961046, 0.382402688265, 0.324951350689]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.235972866416, 0.479740768671, 0.284286379814]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>M</td>\n", |
| " <td>[0.3711399436, 0.21823567152, 0.410624355078]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>M</td>\n", |
| " <td>[0.343350559473, 0.223895892501, 0.432753533125]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>30</td>\n", |
| " <td>M</td>\n", |
| " <td>[0.359976351261, 0.246053755283, 0.393969863653]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>31</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.437169611454, 0.199478805065, 0.363351553679]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>35</td>\n", |
| " <td>F</td>\n", |
| " <td>[0.516167163849, 0.170660674572, 0.313172131777]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>36</td>\n", |
| " <td>I</td>\n", |
| " <td>[0.252817928791, 0.48461201787, 0.262570023537]</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, u'I', [0.312986373901, 0.34792137146, 0.339092254639]),\n", |
| " (12, u'I', [0.337030380964, 0.379457473755, 0.283512145281]),\n", |
| " (15, u'I', [0.292645961046, 0.382402688265, 0.324951350689]),\n", |
| " (20, u'I', [0.235972866416, 0.479740768671, 0.284286379814]),\n", |
| " (23, u'M', [0.3711399436, 0.21823567152, 0.410624355078]),\n", |
| " (26, u'M', [0.343350559473, 0.223895892501, 0.432753533125]),\n", |
| " (30, u'M', [0.359976351261, 0.246053755283, 0.393969863653]),\n", |
| " (31, u'F', [0.437169611454, 0.199478805065, 0.363351553679]),\n", |
| " (35, u'F', [0.516167163849, 0.170660674572, 0.313172131777]),\n", |
| " (36, u'I', [0.252817928791, 0.48461201787, 0.262570023537])]" |
| ] |
| }, |
| "execution_count": 31, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS xgb_grid_score_out, xgb_grid_score_out_metrics, xgb_grid_score_out_roc_curve;\n", |
| "\n", |
| "-- Score abalone with the grid model whose params_index is 2\n", |
| "SELECT madlib.xgboost_predict(\n", |
| " 'abalone', -- test_table\n", |
| " 'xgb_grid_out', -- model_table\n", |
| " 'xgb_grid_score_out', -- predict_output_table\n", |
| " 'id', -- id_column\n", |
| " 'sex', -- class_label\n", |
| " 2 -- model_filters (params_index of the model to use)\n", |
| ");\n", |
| "\n", |
| "-- ORDER BY makes the previewed rows reproducible across runs\n", |
| "SELECT * FROM xgb_grid_score_out ORDER BY id LIMIT 10;" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 2", |
| "language": "python", |
| "name": "python2" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 2 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython2", |
| "version": "2.7.17" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 1 |
| } |