| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# Elastic net (MADlib v1.10+)\n", |
| "Demonstrates elastic net, including these updates:\n", |
| "- in MADlib 1.10: grouping and cross validation introduced \n", |
| "- in MADlib 1.13: report negative root mean squared error instead of the negative mean squared error" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "The sql extension is already loaded. To reload it, use:\n", |
| " %reload_ext sql\n" |
| ] |
| } |
| ], |
| "source": [ |
| "%load_ext sql" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "u'Connected: gpadmin@madlib'" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Greenplum Database 5.4.0 on GCP (demo machine)\n", |
| "%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n", |
| " \n", |
| "# PostgreSQL local\n", |
| "#%sql postgresql://fmcquillan@localhost:5432/madlib" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>version</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-gabafa66, cmake configuration time: Wed Jul 11 00:36:05 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-gabafa66, cmake configuration time: Wed Jul 11 00:36:05 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]" |
| ] |
| }, |
| "execution_count": 13, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%sql select madlib.version();\n", |
| "#%sql select version();" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 1. Create data set\n", |
| "House prices and characteristics." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "Done.\n", |
| "27 rows affected.\n", |
| "27 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>tax</th>\n", |
| " <th>bedroom</th>\n", |
| " <th>bath</th>\n", |
| " <th>price</th>\n", |
| " <th>size</th>\n", |
| " <th>lot</th>\n", |
| " <th>zipcode</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>590</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>50000</td>\n", |
| " <td>770</td>\n", |
| " <td>22100</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>1050</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>85000</td>\n", |
| " <td>1410</td>\n", |
| " <td>12000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>20</td>\n", |
| " <td>3</td>\n", |
| " <td>1.0</td>\n", |
| " <td>22500</td>\n", |
| " <td>1060</td>\n", |
| " <td>3500</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>870</td>\n", |
| " <td>2</td>\n", |
| " <td>2.0</td>\n", |
| " <td>90000</td>\n", |
| " <td>1300</td>\n", |
| " <td>17500</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>1320</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>133000</td>\n", |
| " <td>1500</td>\n", |
| " <td>30000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>1350</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>90500</td>\n", |
| " <td>820</td>\n", |
| " <td>25700</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>2790</td>\n", |
| " <td>3</td>\n", |
| " <td>2.5</td>\n", |
| " <td>260000</td>\n", |
| " <td>2130</td>\n", |
| " <td>25000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>680</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>142500</td>\n", |
| " <td>1170</td>\n", |
| " <td>22000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>1840</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>160000</td>\n", |
| " <td>1500</td>\n", |
| " <td>19000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>3680</td>\n", |
| " <td>4</td>\n", |
| " <td>2.0</td>\n", |
| " <td>240000</td>\n", |
| " <td>2790</td>\n", |
| " <td>20000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>11</td>\n", |
| " <td>1660</td>\n", |
| " <td>3</td>\n", |
| " <td>1.0</td>\n", |
| " <td>87000</td>\n", |
| " <td>1030</td>\n", |
| " <td>17500</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>1620</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>118600</td>\n", |
| " <td>1250</td>\n", |
| " <td>20000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>13</td>\n", |
| " <td>3100</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>140000</td>\n", |
| " <td>1760</td>\n", |
| " <td>38000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>14</td>\n", |
| " <td>2070</td>\n", |
| " <td>2</td>\n", |
| " <td>3.0</td>\n", |
| " <td>148000</td>\n", |
| " <td>1550</td>\n", |
| " <td>14000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>650</td>\n", |
| " <td>3</td>\n", |
| " <td>1.5</td>\n", |
| " <td>65000</td>\n", |
| " <td>1450</td>\n", |
| " <td>12000</td>\n", |
| " <td>94301</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>770</td>\n", |
| " <td>2</td>\n", |
| " <td>2.0</td>\n", |
| " <td>91000</td>\n", |
| " <td>1300</td>\n", |
| " <td>17500</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>17</td>\n", |
| " <td>1220</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>132300</td>\n", |
| " <td>1500</td>\n", |
| " <td>30000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>1150</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>91100</td>\n", |
| " <td>820</td>\n", |
| " <td>25700</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>2690</td>\n", |
| " <td>3</td>\n", |
| " <td>2.5</td>\n", |
| " <td>260011</td>\n", |
| " <td>2130</td>\n", |
| " <td>25000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>780</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>141800</td>\n", |
| " <td>1170</td>\n", |
| " <td>22000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>21</td>\n", |
| " <td>1910</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>160900</td>\n", |
| " <td>1500</td>\n", |
| " <td>19000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>3600</td>\n", |
| " <td>4</td>\n", |
| " <td>2.0</td>\n", |
| " <td>239000</td>\n", |
| " <td>2790</td>\n", |
| " <td>20000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>1600</td>\n", |
| " <td>3</td>\n", |
| " <td>1.0</td>\n", |
| " <td>81010</td>\n", |
| " <td>1030</td>\n", |
| " <td>17500</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>1590</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>117910</td>\n", |
| " <td>1250</td>\n", |
| " <td>20000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>25</td>\n", |
| " <td>3200</td>\n", |
| " <td>3</td>\n", |
| " <td>2.0</td>\n", |
| " <td>141100</td>\n", |
| " <td>1760</td>\n", |
| " <td>38000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>2270</td>\n", |
| " <td>2</td>\n", |
| " <td>3.0</td>\n", |
| " <td>148011</td>\n", |
| " <td>1550</td>\n", |
| " <td>14000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>27</td>\n", |
| " <td>750</td>\n", |
| " <td>3</td>\n", |
| " <td>1.5</td>\n", |
| " <td>66000</td>\n", |
| " <td>1450</td>\n", |
| " <td>12000</td>\n", |
| " <td>76010</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 590, 2, 1.0, 50000, 770, 22100, 94301),\n", |
| " (2, 1050, 3, 2.0, 85000, 1410, 12000, 94301),\n", |
| " (3, 20, 3, 1.0, 22500, 1060, 3500, 94301),\n", |
| " (4, 870, 2, 2.0, 90000, 1300, 17500, 94301),\n", |
| " (5, 1320, 3, 2.0, 133000, 1500, 30000, 94301),\n", |
| " (6, 1350, 2, 1.0, 90500, 820, 25700, 94301),\n", |
| " (7, 2790, 3, 2.5, 260000, 2130, 25000, 94301),\n", |
| " (8, 680, 2, 1.0, 142500, 1170, 22000, 94301),\n", |
| " (9, 1840, 3, 2.0, 160000, 1500, 19000, 94301),\n", |
| " (10, 3680, 4, 2.0, 240000, 2790, 20000, 94301),\n", |
| " (11, 1660, 3, 1.0, 87000, 1030, 17500, 94301),\n", |
| " (12, 1620, 3, 2.0, 118600, 1250, 20000, 94301),\n", |
| " (13, 3100, 3, 2.0, 140000, 1760, 38000, 94301),\n", |
| " (14, 2070, 2, 3.0, 148000, 1550, 14000, 94301),\n", |
| " (15, 650, 3, 1.5, 65000, 1450, 12000, 94301),\n", |
| " (16, 770, 2, 2.0, 91000, 1300, 17500, 76010),\n", |
| " (17, 1220, 3, 2.0, 132300, 1500, 30000, 76010),\n", |
| " (18, 1150, 2, 1.0, 91100, 820, 25700, 76010),\n", |
| " (19, 2690, 3, 2.5, 260011, 2130, 25000, 76010),\n", |
| " (20, 780, 2, 1.0, 141800, 1170, 22000, 76010),\n", |
| " (21, 1910, 3, 2.0, 160900, 1500, 19000, 76010),\n", |
| " (22, 3600, 4, 2.0, 239000, 2790, 20000, 76010),\n", |
| " (23, 1600, 3, 1.0, 81010, 1030, 17500, 76010),\n", |
| " (24, 1590, 3, 2.0, 117910, 1250, 20000, 76010),\n", |
| " (25, 3200, 3, 2.0, 141100, 1760, 38000, 76010),\n", |
| " (26, 2270, 2, 3.0, 148011, 1550, 14000, 76010),\n", |
| " (27, 750, 3, 1.5, 66000, 1450, 12000, 76010)]" |
| ] |
| }, |
| "execution_count": 4, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql \n", |
| "DROP TABLE IF EXISTS houses;\n", |
| "\n", |
| "CREATE TABLE houses ( id INT,\n", |
| " tax INT,\n", |
| " bedroom INT,\n", |
| " bath FLOAT,\n", |
| " price INT,\n", |
| " size INT,\n", |
| " lot INT,\n", |
| " zipcode INT);\n", |
| "\n", |
| "INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode) VALUES\n", |
| "(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301),\n", |
| "(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301),\n", |
| "(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301),\n", |
| "(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301),\n", |
| "(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301),\n", |
| "(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301),\n", |
| "(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301),\n", |
| "(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301),\n", |
| "(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301),\n", |
| "(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301),\n", |
| "(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301),\n", |
| "(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301),\n", |
| "(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301),\n", |
| "(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301),\n", |
| "(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301),\n", |
| "(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010),\n", |
| "(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010),\n", |
| "(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010),\n", |
| "(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010),\n", |
| "(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010),\n", |
| "(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010),\n", |
| "(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010),\n", |
| "(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010),\n", |
| "(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010),\n", |
| "(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010),\n", |
| "(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010),\n", |
| "(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010);\n", |
| "\n", |
| "SELECT * FROM houses ORDER BY id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 2. Train the model" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>family</th>\n", |
| " <th>features</th>\n", |
| " <th>features_selected</th>\n", |
| " <th>coef_nonzero</th>\n", |
| " <th>coef_all</th>\n", |
| " <th>intercept</th>\n", |
| " <th>log_likelihood</th>\n", |
| " <th>standardize</th>\n", |
| " <th>iteration_run</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>gaussian</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n", |
| " <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n", |
| " <td>-7798.78310728</td>\n", |
| " <td>-512248641.97</td>\n", |
| " <td>True</td>\n", |
| " <td>10000</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [22.7851318679, 10707.9553682, 54.7961166559], [22.7851318679, 10707.9553682, 54.7961166559], -7798.78310728, -512248641.97, True, 10000)]" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en, houses_en_summary;\n", |
| "SELECT madlib.elastic_net_train( 'houses', -- Source table\n", |
| " 'houses_en', -- Result table\n", |
| " 'price', -- Dependent variable\n", |
| " 'array[tax, bath, size]', -- Independent variable\n", |
| " 'gaussian', -- Regression family\n", |
| " 0.5, -- Alpha value\n", |
| " 0.1, -- Lambda value\n", |
| " TRUE, -- Standardize\n", |
| " NULL, -- Grouping column(s)\n", |
| " 'fista', -- Optimizer\n", |
| " '', -- Optimizer parameters\n", |
| " NULL, -- Excluded columns\n", |
| " 10000, -- Maximum iterations\n", |
| " 1e-6 -- Tolerance value\n", |
| " );\n", |
| "SELECT * FROM houses_en;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 3. Prediction\n", |
| "Use the coefficients and intercept from the model table to predict prices and evaluate residuals." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "27 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>price</th>\n", |
| " <th>predict</th>\n", |
| " <th>residual</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>50000</td>\n", |
| " <td>58545.409888</td>\n", |
| " <td>-8545.40988802</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>85000</td>\n", |
| " <td>114804.040575</td>\n", |
| " <td>-29804.0405752</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>22500</td>\n", |
| " <td>61448.7585535</td>\n", |
| " <td>-38948.7585535</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>90000</td>\n", |
| " <td>104675.144007</td>\n", |
| " <td>-14675.1440069</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>133000</td>\n", |
| " <td>125887.676679</td>\n", |
| " <td>7112.3233214</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>90500</td>\n", |
| " <td>78601.9159404</td>\n", |
| " <td>11898.0840596</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>260000</td>\n", |
| " <td>199257.351702</td>\n", |
| " <td>60742.6482983</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>142500</td>\n", |
| " <td>82514.5184185</td>\n", |
| " <td>59985.4815815</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>160000</td>\n", |
| " <td>137735.94525</td>\n", |
| " <td>22264.0547501</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>240000</td>\n", |
| " <td>250347.578373</td>\n", |
| " <td>-10347.578373</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>11</td>\n", |
| " <td>87000</td>\n", |
| " <td>97172.4913172</td>\n", |
| " <td>-10172.4913172</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>118600</td>\n", |
| " <td>119024.187075</td>\n", |
| " <td>-424.187074993</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>13</td>\n", |
| " <td>140000</td>\n", |
| " <td>180692.201734</td>\n", |
| " <td>-40692.201734</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>14</td>\n", |
| " <td>148000</td>\n", |
| " <td>156424.286781</td>\n", |
| " <td>-8424.28678052</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>65000</td>\n", |
| " <td>102527.85481</td>\n", |
| " <td>-37527.8548102</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>91000</td>\n", |
| " <td>102396.63082</td>\n", |
| " <td>-11396.6308201</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>17</td>\n", |
| " <td>132300</td>\n", |
| " <td>123609.163492</td>\n", |
| " <td>8690.83650819</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>91100</td>\n", |
| " <td>74044.8895668</td>\n", |
| " <td>17055.1104332</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>260011</td>\n", |
| " <td>196978.838515</td>\n", |
| " <td>63032.1614851</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>141800</td>\n", |
| " <td>84793.0316053</td>\n", |
| " <td>57006.9683947</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>21</td>\n", |
| " <td>160900</td>\n", |
| " <td>139330.904481</td>\n", |
| " <td>21569.0955193</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>239000</td>\n", |
| " <td>248524.767824</td>\n", |
| " <td>-9524.76782352</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>81010</td>\n", |
| " <td>95805.3834051</td>\n", |
| " <td>-14795.3834051</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>117910</td>\n", |
| " <td>118340.633119</td>\n", |
| " <td>-430.633118956</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>25</td>\n", |
| " <td>141100</td>\n", |
| " <td>182970.714921</td>\n", |
| " <td>-41870.7149208</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>148011</td>\n", |
| " <td>160981.313154</td>\n", |
| " <td>-12970.3131541</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>27</td>\n", |
| " <td>66000</td>\n", |
| " <td>104806.367997</td>\n", |
| " <td>-38806.367997</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 50000, 58545.409888024, -8545.409888024),\n", |
| " (2, 85000, 114804.040575234, -29804.040575234),\n", |
| " (3, 22500, 61448.758553532, -38948.758553532),\n", |
| " (4, 90000, 104675.144006863, -14675.144006863),\n", |
| " (5, 133000, 125887.676678598, 7112.323321402),\n", |
| " (6, 90500, 78601.915940423, 11898.084059577),\n", |
| " (7, 260000, 199257.351701728, 60742.648298272),\n", |
| " (8, 142500, 82514.518418495, 59985.481581505),\n", |
| " (9, 160000, 137735.945249906, 22264.054750094),\n", |
| " (10, 240000, 250347.578372953, -10347.578372953),\n", |
| " (11, 87000, 97172.491317211, -10172.491317211),\n", |
| " (12, 118600, 119024.187074993, -424.187074992995),\n", |
| " (13, 140000, 180692.201733994, -40692.201733994),\n", |
| " (14, 148000, 156424.286780518, -8424.28678051801),\n", |
| " (15, 65000, 102527.85481021, -37527.85481021),\n", |
| " (16, 91000, 102396.630820073, -11396.630820073),\n", |
| " (17, 132300, 123609.163491808, 8690.83650819201),\n", |
| " (18, 91100, 74044.889566843, 17055.110433157),\n", |
| " (19, 260011, 196978.838514938, 63032.161485062),\n", |
| " (20, 141800, 84793.031605285, 57006.968394715),\n", |
| " (21, 160900, 139330.904480659, 21569.095519341),\n", |
| " (22, 239000, 248524.767823521, -9524.76782352099),\n", |
| " (23, 81010, 95805.383405137, -14795.383405137),\n", |
| " (24, 117910, 118340.633118956, -430.633118956001),\n", |
| " (25, 141100, 182970.714920784, -41870.714920784),\n", |
| " (26, 148011, 160981.313154098, -12970.313154098),\n", |
| " (27, 66000, 104806.367997, -38806.367997)]" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT id, price, predict, price - predict AS residual\n", |
| "FROM (\n", |
| " SELECT\n", |
| " houses.*,\n", |
| " madlib.elastic_net_gaussian_predict(\n", |
| " m.coef_all, -- Coefficients\n", |
| " m.intercept, -- Intercept\n", |
| " ARRAY[tax,bath,size] -- Features (corresponding to coefficients)\n", |
| " ) AS predict\n", |
| " FROM houses, houses_en m) s\n", |
| "ORDER BY id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 4. Grouping\n", |
| "Train a separate model for each zip code by grouping on the zipcode column." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "2 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>zipcode</th>\n", |
| " <th>family</th>\n", |
| " <th>features</th>\n", |
| " <th>features_selected</th>\n", |
| " <th>coef_nonzero</th>\n", |
| " <th>coef_all</th>\n", |
| " <th>intercept</th>\n", |
| " <th>log_likelihood</th>\n", |
| " <th>standardize</th>\n", |
| " <th>iteration_run</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>76010</td>\n", |
| " <td>gaussian</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n", |
| " <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n", |
| " <td>14.7294468096</td>\n", |
| " <td>-525667117.987</td>\n", |
| " <td>True</td>\n", |
| " <td>10000</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>94301</td>\n", |
| " <td>gaussian</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n", |
| " <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n", |
| " <td>-11145.5017384</td>\n", |
| " <td>-520358795.785</td>\n", |
| " <td>True</td>\n", |
| " <td>10000</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(76010, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [14.9802020928, 9133.17041265, 62.8225614522], [14.9802020928, 9133.17041265, 62.8225614522], 14.7294468096, -525667117.987, True, 10000),\n", |
| " (94301, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [27.6945649037, 11509.010807, 49.0945476263], [27.6945649037, 11509.010807, 49.0945476263], -11145.5017384, -520358795.785, True, 10000)]" |
| ] |
| }, |
| "execution_count": 8, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en1, houses_en1_summary;\n", |
| "SELECT madlib.elastic_net_train( 'houses', -- Source table\n", |
| " 'houses_en1', -- Result table\n", |
| " 'price', -- Dependent variable\n", |
| " 'array[tax, bath, size]', -- Independent variable\n", |
| " 'gaussian', -- Regression family\n", |
| " 0.5, -- Alpha value\n", |
| " 0.1, -- Lambda value\n", |
| " TRUE, -- Standardize\n", |
| " 'zipcode', -- Grouping column(s)\n", |
| " 'fista', -- Optimizer\n", |
| " '', -- Optimizer parameters\n", |
| " NULL, -- Excluded columns\n", |
| " 10000, -- Maximum iterations\n", |
| " 1e-6 -- Tolerance value\n", |
| " );\n", |
| "SELECT * FROM houses_en1;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Use the prediction function to generate predictions from the grouped model and evaluate residuals." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "1 rows affected.\n", |
| "27 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>price</th>\n", |
| " <th>prediction</th>\n", |
| " <th>residual</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>50000</td>\n", |
| " <td>54506.104034</td>\n", |
| " <td>-4506.10403403</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>85000</td>\n", |
| " <td>110175.125178</td>\n", |
| " <td>-25175.1251776</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>22500</td>\n", |
| " <td>52957.6208506</td>\n", |
| " <td>-30457.6208506</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>90000</td>\n", |
| " <td>99789.703256</td>\n", |
| " <td>-9789.70325601</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>133000</td>\n", |
| " <td>122071.166988</td>\n", |
| " <td>10928.8330121</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>90500</td>\n", |
| " <td>78008.7007422</td>\n", |
| " <td>12491.2992578</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>260000</td>\n", |
| " <td>199466.247804</td>\n", |
| " <td>60533.7521956</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>142500</td>\n", |
| " <td>76636.4339259</td>\n", |
| " <td>65863.5660741</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>160000</td>\n", |
| " <td>136472.340738</td>\n", |
| " <td>23527.6592621</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>240000</td>\n", |
| " <td>250762.306599</td>\n", |
| " <td>-10762.3065986</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>11</td>\n", |
| " <td>87000</td>\n", |
| " <td>96903.8708638</td>\n", |
| " <td>-9903.87086383</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>118600</td>\n", |
| " <td>118105.899552</td>\n", |
| " <td>494.100447531</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>13</td>\n", |
| " <td>140000</td>\n", |
| " <td>184132.074899</td>\n", |
| " <td>-44132.0748994</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>14</td>\n", |
| " <td>148000</td>\n", |
| " <td>156805.828854</td>\n", |
| " <td>-8805.82885402</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>65000</td>\n", |
| " <td>95306.5757176</td>\n", |
| " <td>-30306.5757176</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>91000</td>\n", |
| " <td>111485.155771</td>\n", |
| " <td>-20485.1557714</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>17</td>\n", |
| " <td>132300</td>\n", |
| " <td>130790.759004</td>\n", |
| " <td>1509.24099637</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>91100</td>\n", |
| " <td>77889.632657</td>\n", |
| " <td>13210.367343</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>260011</td>\n", |
| " <td>196956.455001</td>\n", |
| " <td>63054.5449987</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>141800</td>\n", |
| " <td>94334.8543909</td>\n", |
| " <td>47465.1456091</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>21</td>\n", |
| " <td>160900</td>\n", |
| " <td>141127.098448</td>\n", |
| " <td>19772.9015523</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>239000</td>\n", |
| " <td>247484.744258</td>\n", |
| " <td>-8484.74425783</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>81010</td>\n", |
| " <td>97823.4615037</td>\n", |
| " <td>-16813.4615037</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>117910</td>\n", |
| " <td>120627.793415</td>\n", |
| " <td>-2717.79341491</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>25</td>\n", |
| " <td>141100</td>\n", |
| " <td>176785.425125</td>\n", |
| " <td>-35685.4251249</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>148011</td>\n", |
| " <td>158794.269686</td>\n", |
| " <td>-10783.2696863</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>27</td>\n", |
| " <td>66000</td>\n", |
| " <td>116042.350741</td>\n", |
| " <td>-50042.3507411</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 50000, 54506.104034034, -4506.104034034),\n", |
| " (2, 85000, 110175.125177568, -25175.125177568),\n", |
| " (3, 22500, 52957.620850552, -30457.620850552),\n", |
| " (4, 90000, 99789.703256009, -9789.703256009),\n", |
| " (5, 133000, 122071.166987934, 10928.833012066),\n", |
| " (6, 90500, 78008.700742161, 12491.299257839),\n", |
| " (7, 260000, 199466.247804442, 60533.752195558),\n", |
| " (8, 142500, 76636.433925887, 65863.566074113),\n", |
| " (9, 160000, 136472.340737858, 23527.659262142),\n", |
| " (10, 240000, 250762.306598593, -10762.306598593),\n", |
| " (11, 87000, 96903.870863831, -9903.87086383101),\n", |
| " (12, 118600, 118105.899552469, 494.100447531004),\n", |
| " (13, 140000, 184132.074899358, -44132.074899358),\n", |
| " (14, 148000, 156805.828854024, -8805.828854024),\n", |
| " (15, 65000, 95306.57571764, -30306.57571764),\n", |
| " (16, 91000, 111485.155771426, -20485.1557714256),\n", |
| " (17, 132300, 130790.759003626, 1509.2409963744),\n", |
| " (18, 91100, 77889.6326569836, 13210.3673430164),\n", |
| " (19, 260011, 196956.455001253, 63054.5449987474),\n", |
| " (20, 141800, 94334.8543909176, 47465.1456090824),\n", |
| " (21, 160900, 141127.098447658, 19772.9015523424),\n", |
| " (22, 239000, 247484.744257828, -8484.74425782761),\n", |
| " (23, 81010, 97823.4615037056, -16813.4615037056),\n", |
| " (24, 117910, 120627.793414912, -2717.7934149116),\n", |
| " (25, 141100, 176785.425124942, -35685.4251249416),\n", |
| " (26, 148011, 158794.269686326, -10783.2696863256),\n", |
| " (27, 66000, 116042.350741075, -50042.3507410746)]" |
| ] |
| }, |
| "execution_count": 9, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en1_prediction;  -- lets this cell be re-run; the predict call below creates the table\n", |
| "\n", |
| "-- Score every row of houses with the per-zipcode models stored in houses_en1.\n", |
| "SELECT madlib.elastic_net_predict(\n", |
| "    'houses_en1',             -- Model table\n", |
| "    'houses',                 -- New source data table\n", |
| "    'id',                     -- Unique ID associated with each row\n", |
| "    'houses_en1_prediction'   -- Table to store prediction result\n", |
| ");\n", |
| "\n", |
| "-- Join the predictions back to the source rows and compute the residuals.\n", |
| "SELECT houses.id,\n", |
| "       houses.price,\n", |
| "       houses_en1_prediction.prediction,\n", |
| "       houses.price - houses_en1_prediction.prediction AS residual\n", |
| "FROM houses_en1_prediction\n", |
| "JOIN houses ON houses.id = houses_en1_prediction.id\n", |
| "ORDER BY houses.id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 5. When coef_nonzero is different from coef_all\n", |
| "Train with alpha = 1 and a large lambda so that a coefficient is driven to zero." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>family</th>\n", |
| " <th>features</th>\n", |
| " <th>features_selected</th>\n", |
| " <th>coef_nonzero</th>\n", |
| " <th>coef_all</th>\n", |
| " <th>intercept</th>\n", |
| " <th>log_likelihood</th>\n", |
| " <th>standardize</th>\n", |
| " <th>iteration_run</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>gaussian</td>\n", |
| " <td>[u'tax', u'bath', u'size']</td>\n", |
| " <td>[u'tax', u'size']</td>\n", |
| " <td>[6.94383308191, 29.7206857861]</td>\n", |
| " <td>[6.94383308191, 0.0, 29.7206857861]</td>\n", |
| " <td>74441.4573381</td>\n", |
| " <td>-1635348584.1</td>\n", |
| " <td>True</td>\n", |
| " <td>173</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'size'], [6.94383308191, 29.7206857861], [6.94383308191, 0.0, 29.7206857861], 74441.4573381, -1635348584.1, True, 173)]" |
| ] |
| }, |
| "execution_count": 10, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en2, houses_en2_summary;\n", |
| "SELECT madlib.elastic_net_train( 'houses', -- Source table\n", |
| " 'houses_en2', -- Result table\n", |
| " 'price', -- Dependent variable\n", |
| " 'array[tax, bath, size]', -- Independent variable\n", |
| " 'gaussian', -- Regression family\n", |
| " 1, -- Alpha value\n", |
| " 30000, -- Lambda value\n", |
| " TRUE, -- Standardize\n", |
| " NULL, -- Grouping column(s)\n", |
| " 'fista', -- Optimizer\n", |
| " '', -- Optimizer parameters\n", |
| " NULL, -- Excluded columns\n", |
| " 10000, -- Maximum iterations\n", |
| " 1e-6 -- Tolerance value\n", |
| " );\n", |
| "SELECT * FROM houses_en2;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Use the prediction function with coef_all (all coefficients, including those that are zero) to evaluate residuals." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "27 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>price</th>\n", |
| " <th>predict</th>\n", |
| " <th>residual</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>50000</td>\n", |
| " <td>101423.246912</td>\n", |
| " <td>-51423.2469117</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>85000</td>\n", |
| " <td>123638.649033</td>\n", |
| " <td>-38638.6490325</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>22500</td>\n", |
| " <td>106084.260933</td>\n", |
| " <td>-83584.260933</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>90000</td>\n", |
| " <td>119119.483641</td>\n", |
| " <td>-29119.4836413</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>133000</td>\n", |
| " <td>128188.345685</td>\n", |
| " <td>4811.65431463</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>90500</td>\n", |
| " <td>108186.594343</td>\n", |
| " <td>-17686.5943433</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>260000</td>\n", |
| " <td>157119.812361</td>\n", |
| " <td>102880.187639</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>142500</td>\n", |
| " <td>113936.466204</td>\n", |
| " <td>28563.5337965</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>160000</td>\n", |
| " <td>131799.138888</td>\n", |
| " <td>28200.861112</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>240000</td>\n", |
| " <td>182915.476423</td>\n", |
| " <td>57084.5235773</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>11</td>\n", |
| " <td>87000</td>\n", |
| " <td>116580.526614</td>\n", |
| " <td>-29580.5266138</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>118600</td>\n", |
| " <td>122841.324163</td>\n", |
| " <td>-4241.32416342</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>13</td>\n", |
| " <td>140000</td>\n", |
| " <td>148275.746876</td>\n", |
| " <td>-8275.74687556</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>14</td>\n", |
| " <td>148000</td>\n", |
| " <td>134882.254786</td>\n", |
| " <td>13117.7452139</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>65000</td>\n", |
| " <td>122049.943231</td>\n", |
| " <td>-57049.9432312</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>91000</td>\n", |
| " <td>118425.100333</td>\n", |
| " <td>-27425.1003331</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>17</td>\n", |
| " <td>132300</td>\n", |
| " <td>127493.962377</td>\n", |
| " <td>4806.03762282</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>91100</td>\n", |
| " <td>106797.827727</td>\n", |
| " <td>-15697.8277269</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>260011</td>\n", |
| " <td>156425.429053</td>\n", |
| " <td>103585.570947</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>141800</td>\n", |
| " <td>114630.849512</td>\n", |
| " <td>27169.1504883</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>21</td>\n", |
| " <td>160900</td>\n", |
| " <td>132285.207204</td>\n", |
| " <td>28614.7927963</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>239000</td>\n", |
| " <td>182359.969776</td>\n", |
| " <td>56640.0302238</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>81010</td>\n", |
| " <td>116163.896629</td>\n", |
| " <td>-35153.8966288</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>117910</td>\n", |
| " <td>122633.009171</td>\n", |
| " <td>-4723.00917096</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>25</td>\n", |
| " <td>141100</td>\n", |
| " <td>148970.130184</td>\n", |
| " <td>-7870.13018375</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>148011</td>\n", |
| " <td>136271.021402</td>\n", |
| " <td>11739.9785975</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>27</td>\n", |
| " <td>66000</td>\n", |
| " <td>122744.326539</td>\n", |
| " <td>-56744.3265394</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 50000, 101423.246911724, -51423.2469117239),\n", |
| " (2, 85000, 123638.649032506, -38638.6490325065),\n", |
| " (3, 22500, 106084.260933004, -83584.2609330042),\n", |
| " (4, 90000, 119119.483641292, -29119.4836412917),\n", |
| " (5, 133000, 128188.345685371, 4811.6543146288),\n", |
| " (6, 90500, 108186.59434328, -17686.5943432805),\n", |
| " (7, 260000, 157119.812361022, 102880.187638978),\n", |
| " (8, 142500, 113936.466203536, 28563.5337964642),\n", |
| " (9, 160000, 131799.138887964, 28200.8611120356),\n", |
| " (10, 240000, 182915.476422748, 57084.5235772522),\n", |
| " (11, 87000, 116580.526613754, -29580.5266137536),\n", |
| " (12, 118600, 122841.324163419, -4241.32416341919),\n", |
| " (13, 140000, 148275.746875557, -8275.746875557),\n", |
| " (14, 148000, 134882.254786109, 13117.7452138913),\n", |
| " (15, 65000, 122049.943231186, -57049.9432311865),\n", |
| " (16, 91000, 118425.100333101, -27425.1003331007),\n", |
| " (17, 132300, 127493.96237718, 4806.03762281981),\n", |
| " (18, 91100, 106797.827726898, -15697.8277268985),\n", |
| " (19, 260011, 156425.429052831, 103585.570947169),\n", |
| " (20, 141800, 114630.849511727, 27169.1504882732),\n", |
| " (21, 160900, 132285.207203698, 28614.7927963019),\n", |
| " (22, 239000, 182359.969776195, 56640.030223805),\n", |
| " (23, 81010, 116163.896628839, -35153.896628839),\n", |
| " (24, 117910, 122633.009170962, -4723.00917096189),\n", |
| " (25, 141100, 148970.130183748, -7870.130183748),\n", |
| " (26, 148011, 136271.021402491, 11739.9785975093),\n", |
| " (27, 66000, 122744.326539377, -56744.3265393775)]" |
| ] |
| }, |
| "execution_count": 11, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT id, price, predict, price - predict AS residual\n", |
| "FROM (\n", |
| " SELECT\n", |
| " houses.*,\n", |
| " madlib.elastic_net_gaussian_predict(\n", |
| " m.coef_all, -- All coefficients\n", |
| " m.intercept, -- Intercept\n", |
| " ARRAY[tax,bath,size] -- All features\n", |
| " ) AS predict\n", |
| " FROM houses, houses_en2 m) s\n", |
| "ORDER BY id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "We can speed up prediction by passing only coef_nonzero to the prediction function when evaluating residuals. This requires the user to examine the features_selected column in the result table to construct the correct set of independent variables to provide to the prediction function." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "27 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>id</th>\n", |
| " <th>price</th>\n", |
| " <th>predict</th>\n", |
| " <th>residual</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>50000</td>\n", |
| " <td>101423.246912</td>\n", |
| " <td>-51423.2469117</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>85000</td>\n", |
| " <td>123638.649033</td>\n", |
| " <td>-38638.6490325</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>22500</td>\n", |
| " <td>106084.260933</td>\n", |
| " <td>-83584.260933</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>90000</td>\n", |
| " <td>119119.483641</td>\n", |
| " <td>-29119.4836413</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>133000</td>\n", |
| " <td>128188.345685</td>\n", |
| " <td>4811.65431463</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>90500</td>\n", |
| " <td>108186.594343</td>\n", |
| " <td>-17686.5943433</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>260000</td>\n", |
| " <td>157119.812361</td>\n", |
| " <td>102880.187639</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>142500</td>\n", |
| " <td>113936.466204</td>\n", |
| " <td>28563.5337965</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>160000</td>\n", |
| " <td>131799.138888</td>\n", |
| " <td>28200.861112</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>240000</td>\n", |
| " <td>182915.476423</td>\n", |
| " <td>57084.5235773</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>11</td>\n", |
| " <td>87000</td>\n", |
| " <td>116580.526614</td>\n", |
| " <td>-29580.5266138</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>12</td>\n", |
| " <td>118600</td>\n", |
| " <td>122841.324163</td>\n", |
| " <td>-4241.32416342</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>13</td>\n", |
| " <td>140000</td>\n", |
| " <td>148275.746876</td>\n", |
| " <td>-8275.74687556</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>14</td>\n", |
| " <td>148000</td>\n", |
| " <td>134882.254786</td>\n", |
| " <td>13117.7452139</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>15</td>\n", |
| " <td>65000</td>\n", |
| " <td>122049.943231</td>\n", |
| " <td>-57049.9432312</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>16</td>\n", |
| " <td>91000</td>\n", |
| " <td>118425.100333</td>\n", |
| " <td>-27425.1003331</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>17</td>\n", |
| " <td>132300</td>\n", |
| " <td>127493.962377</td>\n", |
| " <td>4806.03762282</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>18</td>\n", |
| " <td>91100</td>\n", |
| " <td>106797.827727</td>\n", |
| " <td>-15697.8277269</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>19</td>\n", |
| " <td>260011</td>\n", |
| " <td>156425.429053</td>\n", |
| " <td>103585.570947</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>20</td>\n", |
| " <td>141800</td>\n", |
| " <td>114630.849512</td>\n", |
| " <td>27169.1504883</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>21</td>\n", |
| " <td>160900</td>\n", |
| " <td>132285.207204</td>\n", |
| " <td>28614.7927963</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>22</td>\n", |
| " <td>239000</td>\n", |
| " <td>182359.969776</td>\n", |
| " <td>56640.0302238</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>23</td>\n", |
| " <td>81010</td>\n", |
| " <td>116163.896629</td>\n", |
| " <td>-35153.8966288</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>24</td>\n", |
| " <td>117910</td>\n", |
| " <td>122633.009171</td>\n", |
| " <td>-4723.00917096</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>25</td>\n", |
| " <td>141100</td>\n", |
| " <td>148970.130184</td>\n", |
| " <td>-7870.13018375</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>26</td>\n", |
| " <td>148011</td>\n", |
| " <td>136271.021402</td>\n", |
| " <td>11739.9785975</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>27</td>\n", |
| " <td>66000</td>\n", |
| " <td>122744.326539</td>\n", |
| " <td>-56744.3265394</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 50000, 101423.246911724, -51423.2469117239),\n", |
| " (2, 85000, 123638.649032506, -38638.6490325065),\n", |
| " (3, 22500, 106084.260933004, -83584.2609330042),\n", |
| " (4, 90000, 119119.483641292, -29119.4836412917),\n", |
| " (5, 133000, 128188.345685371, 4811.6543146288),\n", |
| " (6, 90500, 108186.59434328, -17686.5943432805),\n", |
| " (7, 260000, 157119.812361022, 102880.187638978),\n", |
| " (8, 142500, 113936.466203536, 28563.5337964642),\n", |
| " (9, 160000, 131799.138887964, 28200.8611120356),\n", |
| " (10, 240000, 182915.476422748, 57084.5235772522),\n", |
| " (11, 87000, 116580.526613754, -29580.5266137536),\n", |
| " (12, 118600, 122841.324163419, -4241.32416341919),\n", |
| " (13, 140000, 148275.746875557, -8275.746875557),\n", |
| " (14, 148000, 134882.254786109, 13117.7452138913),\n", |
| " (15, 65000, 122049.943231186, -57049.9432311865),\n", |
| " (16, 91000, 118425.100333101, -27425.1003331007),\n", |
| " (17, 132300, 127493.96237718, 4806.03762281981),\n", |
| " (18, 91100, 106797.827726898, -15697.8277268985),\n", |
| " (19, 260011, 156425.429052831, 103585.570947169),\n", |
| " (20, 141800, 114630.849511727, 27169.1504882732),\n", |
| " (21, 160900, 132285.207203698, 28614.7927963019),\n", |
| " (22, 239000, 182359.969776195, 56640.030223805),\n", |
| " (23, 81010, 116163.896628839, -35153.896628839),\n", |
| " (24, 117910, 122633.009170962, -4723.00917096189),\n", |
| " (25, 141100, 148970.130183748, -7870.130183748),\n", |
| " (26, 148011, 136271.021402491, 11739.9785975093),\n", |
| " (27, 66000, 122744.326539377, -56744.3265393775)]" |
| ] |
| }, |
| "execution_count": 12, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT id, price, predict, price - predict AS residual\n", |
| "FROM (\n", |
| " SELECT\n", |
| " houses.*,\n", |
| " madlib.elastic_net_gaussian_predict(\n", |
| " m.coef_nonzero, -- Non-zero coefficients\n", |
| " m.intercept, -- Intercept\n", |
| " ARRAY[tax,size] -- Features corresponding to non-zero coefficients\n", |
| " ) AS predict\n", |
| " FROM houses, houses_en2 m) s\n", |
| "ORDER BY id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "## 6. Cross validation\n", |
| "Reuse the houses table above. Here we use 3-fold cross validation with 3 automatically generated lambda values and 3 specified alpha values. (This can take some time to run since elastic net is effectively being called 27 times for these combinations, then a 28th time for the whole dataset.)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n", |
| "SELECT madlib.elastic_net_train( 'houses', -- Source table\n", |
| " 'houses_en3', -- Result table\n", |
| " 'price', -- Dependent variable\n", |
| " 'array[tax, bath, size]', -- Independent variable\n", |
| " 'gaussian', -- Regression family\n", |
| " 0.5, -- Alpha value\n", |
| " 0.1, -- Lambda value\n", |
| " TRUE, -- Standardize\n", |
| " NULL, -- Grouping column(s)\n", |
| " 'fista', -- Optimizer\n", |
| " $$ n_folds = 3, -- Optimizer parameters\n", |
| " validation_result=houses_en3_cv,\n", |
| " n_lambdas = 3, \n", |
| " alpha = {0, 0.1, 1}\n", |
| " $$, \n", |
| " NULL, -- Excluded columns\n", |
| " 10000, -- Maximum iterations\n", |
| " 1e-6 -- Tolerance value\n", |
| " );\n", |
| "SELECT * FROM houses_en3;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Details of the cross validation:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "9 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>alpha</th>\n", |
| " <th>lambda_value</th>\n", |
| " <th>mean_neg_loss</th>\n", |
| " <th>std_neg_loss</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.0</td>\n", |
| " <td>0.1</td>\n", |
| " <td>-36094.4685768</td>\n", |
| " <td>10524.4473253</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.1</td>\n", |
| " <td>0.1</td>\n", |
| " <td>-36136.2448004</td>\n", |
| " <td>10682.4136993</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1.0</td>\n", |
| " <td>100.0</td>\n", |
| " <td>-37007.9496501</td>\n", |
| " <td>12679.3781975</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1.0</td>\n", |
| " <td>0.1</td>\n", |
| " <td>-37018.1019927</td>\n", |
| " <td>12716.7438015</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.1</td>\n", |
| " <td>100.0</td>\n", |
| " <td>-59275.6940173</td>\n", |
| " <td>9764.50064237</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.0</td>\n", |
| " <td>100.0</td>\n", |
| " <td>-59380.252681</td>\n", |
| " <td>9763.26373034</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1.0</td>\n", |
| " <td>100000.0</td>\n", |
| " <td>-60353.0220769</td>\n", |
| " <td>9748.10305107</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.1</td>\n", |
| " <td>100000.0</td>\n", |
| " <td>-143513752113000000000000000000000000000000000000000000</td>\n", |
| " <td>157073834312000000000000000000000000000000000000000000</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.0</td>\n", |
| " <td>100000.0</td>\n", |
| " <td>-11248884473800000000000000000000000000000000000000000000</td>\n", |
| " <td>9490568229990000000000000000000000000000000000000000000</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(Decimal('0.0'), Decimal('0.1'), Decimal('-36094.4685768'), Decimal('10524.4473253')),\n", |
| " (Decimal('0.1'), Decimal('0.1'), Decimal('-36136.2448004'), Decimal('10682.4136993')),\n", |
| " (Decimal('1.0'), Decimal('100.0'), Decimal('-37007.9496501'), Decimal('12679.3781975')),\n", |
| " (Decimal('1.0'), Decimal('0.1'), Decimal('-37018.1019927'), Decimal('12716.7438015')),\n", |
| " (Decimal('0.1'), Decimal('100.0'), Decimal('-59275.6940173'), Decimal('9764.50064237')),\n", |
| " (Decimal('0.0'), Decimal('100.0'), Decimal('-59380.252681'), Decimal('9763.26373034')),\n", |
| " (Decimal('1.0'), Decimal('100000.0'), Decimal('-60353.0220769'), Decimal('9748.10305107')),\n", |
| " (Decimal('0.1'), Decimal('100000.0'), Decimal('-143513752113000000000000000000000000000000000000000000'), Decimal('157073834312000000000000000000000000000000000000000000')),\n", |
| " (Decimal('0.0'), Decimal('100000.0'), Decimal('-11248884473800000000000000000000000000000000000000000000'), Decimal('9490568229990000000000000000000000000000000000000000000'))]" |
| ] |
| }, |
| "execution_count": 9, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM houses_en3_cv ORDER BY mean_neg_loss DESC;" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>method</th>\n", |
| " <th>source_table</th>\n", |
| " <th>out_table</th>\n", |
| " <th>dependent_varname</th>\n", |
| " <th>independent_varname</th>\n", |
| " <th>family</th>\n", |
| " <th>alpha</th>\n", |
| " <th>lambda_value</th>\n", |
| " <th>grouping_col</th>\n", |
| " <th>num_all_groups</th>\n", |
| " <th>num_failed_groups</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>elastic_net</td>\n", |
| " <td>houses</td>\n", |
| " <td>houses_en3</td>\n", |
| " <td>price</td>\n", |
| " <td>array[tax, bath, size]</td>\n", |
| " <td>gaussian</td>\n", |
| " <td>0.0</td>\n", |
| " <td>0.1</td>\n", |
| " <td>NULL</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'elastic_net', u'houses', u'houses_en3', u'price', u'array[tax, bath, size]', u'gaussian', 0.0, 0.1, u'NULL', 1, 0)]" |
| ] |
| }, |
| "execution_count": 12, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM houses_en3_summary;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "collapsed": true |
| }, |
| "source": [ |
| "## 6a. Cross validation\n", |
| "Here we use 3-fold cross validation with 3 automatically generated lambda values and 1 alpha value (i.e., elastic net is effectively called 9 times for these combinations)." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n", |
| "SELECT madlib.elastic_net_train( 'houses', -- Source table\n", |
| " 'houses_en3', -- Result table\n", |
| " 'price', -- Dependent variable\n", |
| " 'array[tax, bath, size]', -- Independent variable\n", |
| " 'gaussian', -- Regression family\n", |
| " 0.5, -- Alpha value\n", |
| " 0.1, -- Lambda value\n", |
| " TRUE, -- Standardize\n", |
| " NULL, -- Grouping column(s)\n", |
| " 'fista', -- Optimizer\n", |
| " $$ n_folds = 3, -- Optimizer parameters\n", |
| " validation_result=houses_en3_cv,\n", |
| " n_lambdas = 3\n", |
| " $$, \n", |
| " NULL, -- Excluded columns\n", |
| " 10000, -- Maximum iterations\n", |
| " 1e-6 -- Tolerance value\n", |
| " );\n", |
| "SELECT * FROM houses_en3;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Details of the cross validation:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "3 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>lambda_value</th>\n", |
| " <th>mean_neg_loss</th>\n", |
| " <th>std_neg_loss</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>100000.0</td>\n", |
| " <td>-255543791799000000000000000000000000000000000</td>\n", |
| " <td>442158712729000000000000000000000000000000000</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>100.0</td>\n", |
| " <td>-59332.2198813</td>\n", |
| " <td>8220.8755071</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.1</td>\n", |
| " <td>-51938.9613421</td>\n", |
| " <td>28946.523247</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(Decimal('100000.0'), Decimal('-255543791799000000000000000000000000000000000'), Decimal('442158712729000000000000000000000000000000000')),\n", |
| " (Decimal('100.0'), Decimal('-59332.2198813'), Decimal('8220.8755071')),\n", |
| " (Decimal('0.1'), Decimal('-51938.9613421'), Decimal('28946.523247'))]" |
| ] |
| }, |
| "execution_count": 16, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM houses_en3_cv ORDER BY mean_neg_loss DESC;" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "collapsed": true |
| }, |
| "outputs": [], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM houses_en3_summary;" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 2", |
| "language": "python", |
| "name": "python2" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 2 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython2", |
| "version": "2.7.12" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 1 |
| } |