blob: 7592fe66072aec7c7ef0fa1eecb2916c73d4fca9 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Elastic net (MADlib v1.10+)\n",
"Demonstrates elastic net regularization, including these updates:\n",
"- MADlib 1.10: grouping and cross validation were introduced\n",
"- MADlib 1.13: the negative root mean squared error is reported instead of the negative mean squared error"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The sql extension is already loaded. To reload it, use:\n",
" %reload_ext sql\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.4.0 on GCP (demo machine)\n",
"%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-gabafa66, cmake configuration time: Wed Jul 11 00:36:05 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-23-gabafa66, cmake configuration time: Wed Jul 11 00:36:05 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Create data set\n",
"House prices and characteristics."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"27 rows affected.\n",
"27 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>zipcode</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>94301</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>770</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>91000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>1220</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>132300</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>1150</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>91100</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>2690</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260011</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>780</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>141800</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>1910</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160900</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>3600</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>239000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>1600</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>81010</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>1590</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>117910</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>3200</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>141100</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>2270</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148011</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>750</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>66000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>76010</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100, 94301),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000, 94301),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500, 94301),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500, 94301),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000, 94301),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700, 94301),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000, 94301),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000, 94301),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000, 94301),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000, 94301),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500, 94301),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000, 94301),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000, 94301),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000, 94301),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000, 94301),\n",
" (16, 770, 2, 2.0, 91000, 1300, 17500, 76010),\n",
" (17, 1220, 3, 2.0, 132300, 1500, 30000, 76010),\n",
" (18, 1150, 2, 1.0, 91100, 820, 25700, 76010),\n",
" (19, 2690, 3, 2.5, 260011, 2130, 25000, 76010),\n",
" (20, 780, 2, 1.0, 141800, 1170, 22000, 76010),\n",
" (21, 1910, 3, 2.0, 160900, 1500, 19000, 76010),\n",
" (22, 3600, 4, 2.0, 239000, 2790, 20000, 76010),\n",
" (23, 1600, 3, 1.0, 81010, 1030, 17500, 76010),\n",
" (24, 1590, 3, 2.0, 117910, 1250, 20000, 76010),\n",
" (25, 3200, 3, 2.0, 141100, 1760, 38000, 76010),\n",
" (26, 2270, 2, 3.0, 148011, 1550, 14000, 76010),\n",
" (27, 750, 3, 1.5, 66000, 1450, 12000, 76010)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS houses;\n",
"\n",
"CREATE TABLE houses ( id INT,\n",
" tax INT,\n",
" bedroom INT,\n",
" bath FLOAT,\n",
" price INT,\n",
" size INT,\n",
" lot INT,\n",
" zipcode INT);\n",
"\n",
"INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode) VALUES\n",
"(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301),\n",
"(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301),\n",
"(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301),\n",
"(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301),\n",
"(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301),\n",
"(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301),\n",
"(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301),\n",
"(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301),\n",
"(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301),\n",
"(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301),\n",
"(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301),\n",
"(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301),\n",
"(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301),\n",
"(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301),\n",
"(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301),\n",
"(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010),\n",
"(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010),\n",
"(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010),\n",
"(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010),\n",
"(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010),\n",
"(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010),\n",
"(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010),\n",
"(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010),\n",
"(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010),\n",
"(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010),\n",
"(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010),\n",
"(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010);\n",
"\n",
"SELECT * FROM houses ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Train the model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>family</th>\n",
" <th>features</th>\n",
" <th>features_selected</th>\n",
" <th>coef_nonzero</th>\n",
" <th>coef_all</th>\n",
" <th>intercept</th>\n",
" <th>log_likelihood</th>\n",
" <th>standardize</th>\n",
" <th>iteration_run</th>\n",
" </tr>\n",
" <tr>\n",
" <td>gaussian</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n",
" <td>[22.7851318679, 10707.9553682, 54.7961166559]</td>\n",
" <td>-7798.78310728</td>\n",
" <td>-512248641.97</td>\n",
" <td>True</td>\n",
" <td>10000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [22.7851318679, 10707.9553682, 54.7961166559], [22.7851318679, 10707.9553682, 54.7961166559], -7798.78310728, -512248641.97, True, 10000)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_en, houses_en_summary;\n",
"SELECT madlib.elastic_net_train( 'houses', -- Source table\n",
" 'houses_en', -- Result table\n",
" 'price', -- Dependent variable\n",
" 'array[tax, bath, size]', -- Independent variable\n",
" 'gaussian', -- Regression family\n",
" 0.5, -- Alpha value\n",
" 0.1, -- Lambda value\n",
" TRUE, -- Standardize\n",
" NULL, -- Grouping column(s)\n",
" 'fista', -- Optimizer\n",
" '', -- Optimizer parameters\n",
" NULL, -- Excluded columns\n",
" 10000, -- Maximum iterations\n",
" 1e-6 -- Tolerance value\n",
" );\n",
"SELECT * FROM houses_en;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Prediction\n",
"Evaluate residuals."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>predict</th>\n",
" <th>residual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>50000</td>\n",
" <td>58545.409888</td>\n",
" <td>-8545.40988802</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>85000</td>\n",
" <td>114804.040575</td>\n",
" <td>-29804.0405752</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>22500</td>\n",
" <td>61448.7585535</td>\n",
" <td>-38948.7585535</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>90000</td>\n",
" <td>104675.144007</td>\n",
" <td>-14675.1440069</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>133000</td>\n",
" <td>125887.676679</td>\n",
" <td>7112.3233214</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>90500</td>\n",
" <td>78601.9159404</td>\n",
" <td>11898.0840596</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>260000</td>\n",
" <td>199257.351702</td>\n",
" <td>60742.6482983</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>142500</td>\n",
" <td>82514.5184185</td>\n",
" <td>59985.4815815</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>160000</td>\n",
" <td>137735.94525</td>\n",
" <td>22264.0547501</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>240000</td>\n",
" <td>250347.578373</td>\n",
" <td>-10347.578373</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>87000</td>\n",
" <td>97172.4913172</td>\n",
" <td>-10172.4913172</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>118600</td>\n",
" <td>119024.187075</td>\n",
" <td>-424.187074993</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>140000</td>\n",
" <td>180692.201734</td>\n",
" <td>-40692.201734</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>148000</td>\n",
" <td>156424.286781</td>\n",
" <td>-8424.28678052</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>65000</td>\n",
" <td>102527.85481</td>\n",
" <td>-37527.8548102</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>91000</td>\n",
" <td>102396.63082</td>\n",
" <td>-11396.6308201</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>132300</td>\n",
" <td>123609.163492</td>\n",
" <td>8690.83650819</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>91100</td>\n",
" <td>74044.8895668</td>\n",
" <td>17055.1104332</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>260011</td>\n",
" <td>196978.838515</td>\n",
" <td>63032.1614851</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>141800</td>\n",
" <td>84793.0316053</td>\n",
" <td>57006.9683947</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>160900</td>\n",
" <td>139330.904481</td>\n",
" <td>21569.0955193</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>239000</td>\n",
" <td>248524.767824</td>\n",
" <td>-9524.76782352</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>81010</td>\n",
" <td>95805.3834051</td>\n",
" <td>-14795.3834051</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>117910</td>\n",
" <td>118340.633119</td>\n",
" <td>-430.633118956</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>141100</td>\n",
" <td>182970.714921</td>\n",
" <td>-41870.7149208</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>148011</td>\n",
" <td>160981.313154</td>\n",
" <td>-12970.3131541</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>66000</td>\n",
" <td>104806.367997</td>\n",
" <td>-38806.367997</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 50000, 58545.409888024, -8545.409888024),\n",
" (2, 85000, 114804.040575234, -29804.040575234),\n",
" (3, 22500, 61448.758553532, -38948.758553532),\n",
" (4, 90000, 104675.144006863, -14675.144006863),\n",
" (5, 133000, 125887.676678598, 7112.323321402),\n",
" (6, 90500, 78601.915940423, 11898.084059577),\n",
" (7, 260000, 199257.351701728, 60742.648298272),\n",
" (8, 142500, 82514.518418495, 59985.481581505),\n",
" (9, 160000, 137735.945249906, 22264.054750094),\n",
" (10, 240000, 250347.578372953, -10347.578372953),\n",
" (11, 87000, 97172.491317211, -10172.491317211),\n",
" (12, 118600, 119024.187074993, -424.187074992995),\n",
" (13, 140000, 180692.201733994, -40692.201733994),\n",
" (14, 148000, 156424.286780518, -8424.28678051801),\n",
" (15, 65000, 102527.85481021, -37527.85481021),\n",
" (16, 91000, 102396.630820073, -11396.630820073),\n",
" (17, 132300, 123609.163491808, 8690.83650819201),\n",
" (18, 91100, 74044.889566843, 17055.110433157),\n",
" (19, 260011, 196978.838514938, 63032.161485062),\n",
" (20, 141800, 84793.031605285, 57006.968394715),\n",
" (21, 160900, 139330.904480659, 21569.095519341),\n",
" (22, 239000, 248524.767823521, -9524.76782352099),\n",
" (23, 81010, 95805.383405137, -14795.383405137),\n",
" (24, 117910, 118340.633118956, -430.633118956001),\n",
" (25, 141100, 182970.714920784, -41870.714920784),\n",
" (26, 148011, 160981.313154098, -12970.313154098),\n",
" (27, 66000, 104806.367997, -38806.367997)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT id, price, predict, price - predict AS residual\n",
"FROM (\n",
" SELECT\n",
" houses.*,\n",
" madlib.elastic_net_gaussian_predict(\n",
" m.coef_all, -- Coefficients\n",
" m.intercept, -- Intercept\n",
" ARRAY[tax,bath,size] -- Features (corresponding to coefficients)\n",
" ) AS predict\n",
" FROM houses, houses_en m) s\n",
"ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Grouping\n",
"Group on zip code."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>zipcode</th>\n",
" <th>family</th>\n",
" <th>features</th>\n",
" <th>features_selected</th>\n",
" <th>coef_nonzero</th>\n",
" <th>coef_all</th>\n",
" <th>intercept</th>\n",
" <th>log_likelihood</th>\n",
" <th>standardize</th>\n",
" <th>iteration_run</th>\n",
" </tr>\n",
" <tr>\n",
" <td>76010</td>\n",
" <td>gaussian</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n",
" <td>[14.9802020928, 9133.17041265, 62.8225614522]</td>\n",
" <td>14.7294468096</td>\n",
" <td>-525667117.987</td>\n",
" <td>True</td>\n",
" <td>10000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94301</td>\n",
" <td>gaussian</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n",
" <td>[27.6945649037, 11509.010807, 49.0945476263]</td>\n",
" <td>-11145.5017384</td>\n",
" <td>-520358795.785</td>\n",
" <td>True</td>\n",
" <td>10000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(76010, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [14.9802020928, 9133.17041265, 62.8225614522], [14.9802020928, 9133.17041265, 62.8225614522], 14.7294468096, -525667117.987, True, 10000),\n",
" (94301, u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'bath', u'size'], [27.6945649037, 11509.010807, 49.0945476263], [27.6945649037, 11509.010807, 49.0945476263], -11145.5017384, -520358795.785, True, 10000)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_en1, houses_en1_summary;\n",
"SELECT madlib.elastic_net_train( 'houses', -- Source table\n",
" 'houses_en1', -- Result table\n",
" 'price', -- Dependent variable\n",
" 'array[tax, bath, size]', -- Independent variable\n",
" 'gaussian', -- Regression family\n",
" 0.5, -- Alpha value\n",
" 0.1, -- Lambda value\n",
" TRUE, -- Standardize\n",
" 'zipcode', -- Grouping column(s)\n",
" 'fista', -- Optimizer\n",
" '', -- Optimizer parameters\n",
" NULL, -- Excluded columns\n",
" 10000, -- Maximum iterations\n",
" 1e-6 -- Tolerance value\n",
" );\n",
"SELECT * FROM houses_en1;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prediction function"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"27 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>prediction</th>\n",
" <th>residual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>50000</td>\n",
" <td>54506.104034</td>\n",
" <td>-4506.10403403</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>85000</td>\n",
" <td>110175.125178</td>\n",
" <td>-25175.1251776</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>22500</td>\n",
" <td>52957.6208506</td>\n",
" <td>-30457.6208506</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>90000</td>\n",
" <td>99789.703256</td>\n",
" <td>-9789.70325601</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>133000</td>\n",
" <td>122071.166988</td>\n",
" <td>10928.8330121</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>90500</td>\n",
" <td>78008.7007422</td>\n",
" <td>12491.2992578</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>260000</td>\n",
" <td>199466.247804</td>\n",
" <td>60533.7521956</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>142500</td>\n",
" <td>76636.4339259</td>\n",
" <td>65863.5660741</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>160000</td>\n",
" <td>136472.340738</td>\n",
" <td>23527.6592621</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>240000</td>\n",
" <td>250762.306599</td>\n",
" <td>-10762.3065986</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>87000</td>\n",
" <td>96903.8708638</td>\n",
" <td>-9903.87086383</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>118600</td>\n",
" <td>118105.899552</td>\n",
" <td>494.100447531</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>140000</td>\n",
" <td>184132.074899</td>\n",
" <td>-44132.0748994</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>148000</td>\n",
" <td>156805.828854</td>\n",
" <td>-8805.82885402</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>65000</td>\n",
" <td>95306.5757176</td>\n",
" <td>-30306.5757176</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>91000</td>\n",
" <td>111485.155771</td>\n",
" <td>-20485.1557714</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>132300</td>\n",
" <td>130790.759004</td>\n",
" <td>1509.24099637</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>91100</td>\n",
" <td>77889.632657</td>\n",
" <td>13210.367343</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>260011</td>\n",
" <td>196956.455001</td>\n",
" <td>63054.5449987</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>141800</td>\n",
" <td>94334.8543909</td>\n",
" <td>47465.1456091</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>160900</td>\n",
" <td>141127.098448</td>\n",
" <td>19772.9015523</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>239000</td>\n",
" <td>247484.744258</td>\n",
" <td>-8484.74425783</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>81010</td>\n",
" <td>97823.4615037</td>\n",
" <td>-16813.4615037</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>117910</td>\n",
" <td>120627.793415</td>\n",
" <td>-2717.79341491</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>141100</td>\n",
" <td>176785.425125</td>\n",
" <td>-35685.4251249</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>148011</td>\n",
" <td>158794.269686</td>\n",
" <td>-10783.2696863</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>66000</td>\n",
" <td>116042.350741</td>\n",
" <td>-50042.3507411</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 50000, 54506.104034034, -4506.104034034),\n",
" (2, 85000, 110175.125177568, -25175.125177568),\n",
" (3, 22500, 52957.620850552, -30457.620850552),\n",
" (4, 90000, 99789.703256009, -9789.703256009),\n",
" (5, 133000, 122071.166987934, 10928.833012066),\n",
" (6, 90500, 78008.700742161, 12491.299257839),\n",
" (7, 260000, 199466.247804442, 60533.752195558),\n",
" (8, 142500, 76636.433925887, 65863.566074113),\n",
" (9, 160000, 136472.340737858, 23527.659262142),\n",
" (10, 240000, 250762.306598593, -10762.306598593),\n",
" (11, 87000, 96903.870863831, -9903.87086383101),\n",
" (12, 118600, 118105.899552469, 494.100447531004),\n",
" (13, 140000, 184132.074899358, -44132.074899358),\n",
" (14, 148000, 156805.828854024, -8805.828854024),\n",
" (15, 65000, 95306.57571764, -30306.57571764),\n",
" (16, 91000, 111485.155771426, -20485.1557714256),\n",
" (17, 132300, 130790.759003626, 1509.2409963744),\n",
" (18, 91100, 77889.6326569836, 13210.3673430164),\n",
" (19, 260011, 196956.455001253, 63054.5449987474),\n",
" (20, 141800, 94334.8543909176, 47465.1456090824),\n",
" (21, 160900, 141127.098447658, 19772.9015523424),\n",
" (22, 239000, 247484.744257828, -8484.74425782761),\n",
" (23, 81010, 97823.4615037056, -16813.4615037056),\n",
" (24, 117910, 120627.793414912, -2717.7934149116),\n",
" (25, 141100, 176785.425124942, -35685.4251249416),\n",
" (26, 148011, 158794.269686326, -10783.2696863256),\n",
" (27, 66000, 116042.350741075, -50042.3507410746)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT madlib.elastic_net_predict(\n",
" 'houses_en1', -- Model table\n",
" 'houses', -- New source data table\n",
" 'id', -- Unique ID associated with each row\n",
" 'houses_en1_prediction' -- Table to store prediction result\n",
" );\n",
"\n",
"SELECT houses.id,\n",
" houses.price,\n",
" houses_en1_prediction.prediction,\n",
" houses.price - houses_en1_prediction.prediction AS residual\n",
"FROM houses_en1_prediction, houses\n",
"WHERE houses.id = houses_en1_prediction.id ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. When coef_nonzero is different from coef_all\n",
"Train"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>family</th>\n",
" <th>features</th>\n",
" <th>features_selected</th>\n",
" <th>coef_nonzero</th>\n",
" <th>coef_all</th>\n",
" <th>intercept</th>\n",
" <th>log_likelihood</th>\n",
" <th>standardize</th>\n",
" <th>iteration_run</th>\n",
" </tr>\n",
" <tr>\n",
" <td>gaussian</td>\n",
" <td>[u'tax', u'bath', u'size']</td>\n",
" <td>[u'tax', u'size']</td>\n",
" <td>[6.94383308191, 29.7206857861]</td>\n",
" <td>[6.94383308191, 0.0, 29.7206857861]</td>\n",
" <td>74441.4573381</td>\n",
" <td>-1635348584.1</td>\n",
" <td>True</td>\n",
" <td>173</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'gaussian', [u'tax', u'bath', u'size'], [u'tax', u'size'], [6.94383308191, 29.7206857861], [6.94383308191, 0.0, 29.7206857861], 74441.4573381, -1635348584.1, True, 173)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_en2, houses_en2_summary;\n",
"SELECT madlib.elastic_net_train( 'houses', -- Source table\n",
" 'houses_en2', -- Result table\n",
" 'price', -- Dependent variable\n",
" 'array[tax, bath, size]', -- Independent variable\n",
" 'gaussian', -- Regression family\n",
" 1, -- Alpha value\n",
" 30000, -- Lambda value\n",
" TRUE, -- Standardize\n",
" NULL, -- Grouping column(s)\n",
" 'fista', -- Optimizer\n",
" '', -- Optimizer parameters\n",
" NULL, -- Excluded columns\n",
" 10000, -- Maximum iterations\n",
" 1e-6 -- Tolerance value\n",
" );\n",
"SELECT * FROM houses_en2;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prediction function with coef_all to evaluate residuals."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>predict</th>\n",
" <th>residual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>50000</td>\n",
" <td>101423.246912</td>\n",
" <td>-51423.2469117</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>85000</td>\n",
" <td>123638.649033</td>\n",
" <td>-38638.6490325</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>22500</td>\n",
" <td>106084.260933</td>\n",
" <td>-83584.260933</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>90000</td>\n",
" <td>119119.483641</td>\n",
" <td>-29119.4836413</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>133000</td>\n",
" <td>128188.345685</td>\n",
" <td>4811.65431463</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>90500</td>\n",
" <td>108186.594343</td>\n",
" <td>-17686.5943433</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>260000</td>\n",
" <td>157119.812361</td>\n",
" <td>102880.187639</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>142500</td>\n",
" <td>113936.466204</td>\n",
" <td>28563.5337965</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>160000</td>\n",
" <td>131799.138888</td>\n",
" <td>28200.861112</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>240000</td>\n",
" <td>182915.476423</td>\n",
" <td>57084.5235773</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>87000</td>\n",
" <td>116580.526614</td>\n",
" <td>-29580.5266138</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>118600</td>\n",
" <td>122841.324163</td>\n",
" <td>-4241.32416342</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>140000</td>\n",
" <td>148275.746876</td>\n",
" <td>-8275.74687556</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>148000</td>\n",
" <td>134882.254786</td>\n",
" <td>13117.7452139</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>65000</td>\n",
" <td>122049.943231</td>\n",
" <td>-57049.9432312</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>91000</td>\n",
" <td>118425.100333</td>\n",
" <td>-27425.1003331</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>132300</td>\n",
" <td>127493.962377</td>\n",
" <td>4806.03762282</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>91100</td>\n",
" <td>106797.827727</td>\n",
" <td>-15697.8277269</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>260011</td>\n",
" <td>156425.429053</td>\n",
" <td>103585.570947</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>141800</td>\n",
" <td>114630.849512</td>\n",
" <td>27169.1504883</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>160900</td>\n",
" <td>132285.207204</td>\n",
" <td>28614.7927963</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>239000</td>\n",
" <td>182359.969776</td>\n",
" <td>56640.0302238</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>81010</td>\n",
" <td>116163.896629</td>\n",
" <td>-35153.8966288</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>117910</td>\n",
" <td>122633.009171</td>\n",
" <td>-4723.00917096</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>141100</td>\n",
" <td>148970.130184</td>\n",
" <td>-7870.13018375</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>148011</td>\n",
" <td>136271.021402</td>\n",
" <td>11739.9785975</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>66000</td>\n",
" <td>122744.326539</td>\n",
" <td>-56744.3265394</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 50000, 101423.246911724, -51423.2469117239),\n",
" (2, 85000, 123638.649032506, -38638.6490325065),\n",
" (3, 22500, 106084.260933004, -83584.2609330042),\n",
" (4, 90000, 119119.483641292, -29119.4836412917),\n",
" (5, 133000, 128188.345685371, 4811.6543146288),\n",
" (6, 90500, 108186.59434328, -17686.5943432805),\n",
" (7, 260000, 157119.812361022, 102880.187638978),\n",
" (8, 142500, 113936.466203536, 28563.5337964642),\n",
" (9, 160000, 131799.138887964, 28200.8611120356),\n",
" (10, 240000, 182915.476422748, 57084.5235772522),\n",
" (11, 87000, 116580.526613754, -29580.5266137536),\n",
" (12, 118600, 122841.324163419, -4241.32416341919),\n",
" (13, 140000, 148275.746875557, -8275.746875557),\n",
" (14, 148000, 134882.254786109, 13117.7452138913),\n",
" (15, 65000, 122049.943231186, -57049.9432311865),\n",
" (16, 91000, 118425.100333101, -27425.1003331007),\n",
" (17, 132300, 127493.96237718, 4806.03762281981),\n",
" (18, 91100, 106797.827726898, -15697.8277268985),\n",
" (19, 260011, 156425.429052831, 103585.570947169),\n",
" (20, 141800, 114630.849511727, 27169.1504882732),\n",
" (21, 160900, 132285.207203698, 28614.7927963019),\n",
" (22, 239000, 182359.969776195, 56640.030223805),\n",
" (23, 81010, 116163.896628839, -35153.896628839),\n",
" (24, 117910, 122633.009170962, -4723.00917096189),\n",
" (25, 141100, 148970.130183748, -7870.130183748),\n",
" (26, 148011, 136271.021402491, 11739.9785975093),\n",
" (27, 66000, 122744.326539377, -56744.3265393775)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Score each house with the houses_en2 model (all coefficients) and list the residuals.\n",
"SELECT\n",
"    scored.id,\n",
"    scored.price,\n",
"    scored.predict,\n",
"    scored.price - scored.predict AS residual\n",
"FROM (\n",
"    SELECT\n",
"        h.id,\n",
"        h.price,\n",
"        madlib.elastic_net_gaussian_predict(\n",
"            en2.coef_all,                    -- coefficients for all features\n",
"            en2.intercept,                   -- model intercept\n",
"            ARRAY[h.tax, h.bath, h.size]     -- all three features\n",
"        ) AS predict\n",
"    FROM houses AS h\n",
"    CROSS JOIN houses_en2 AS en2             -- pair every house with the model table\n",
") AS scored\n",
"ORDER BY scored.id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can speed up prediction by passing only `coef_nonzero` (the non-zero coefficients) when evaluating residuals. This requires the user to examine the `features_selected` column in the result table to construct the matching set of independent variables, in the same order, to provide to the prediction function."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>predict</th>\n",
" <th>residual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>50000</td>\n",
" <td>101423.246912</td>\n",
" <td>-51423.2469117</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>85000</td>\n",
" <td>123638.649033</td>\n",
" <td>-38638.6490325</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>22500</td>\n",
" <td>106084.260933</td>\n",
" <td>-83584.260933</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>90000</td>\n",
" <td>119119.483641</td>\n",
" <td>-29119.4836413</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>133000</td>\n",
" <td>128188.345685</td>\n",
" <td>4811.65431463</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>90500</td>\n",
" <td>108186.594343</td>\n",
" <td>-17686.5943433</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>260000</td>\n",
" <td>157119.812361</td>\n",
" <td>102880.187639</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>142500</td>\n",
" <td>113936.466204</td>\n",
" <td>28563.5337965</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>160000</td>\n",
" <td>131799.138888</td>\n",
" <td>28200.861112</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>240000</td>\n",
" <td>182915.476423</td>\n",
" <td>57084.5235773</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>87000</td>\n",
" <td>116580.526614</td>\n",
" <td>-29580.5266138</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>118600</td>\n",
" <td>122841.324163</td>\n",
" <td>-4241.32416342</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>140000</td>\n",
" <td>148275.746876</td>\n",
" <td>-8275.74687556</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>148000</td>\n",
" <td>134882.254786</td>\n",
" <td>13117.7452139</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>65000</td>\n",
" <td>122049.943231</td>\n",
" <td>-57049.9432312</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>91000</td>\n",
" <td>118425.100333</td>\n",
" <td>-27425.1003331</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>132300</td>\n",
" <td>127493.962377</td>\n",
" <td>4806.03762282</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>91100</td>\n",
" <td>106797.827727</td>\n",
" <td>-15697.8277269</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>260011</td>\n",
" <td>156425.429053</td>\n",
" <td>103585.570947</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>141800</td>\n",
" <td>114630.849512</td>\n",
" <td>27169.1504883</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>160900</td>\n",
" <td>132285.207204</td>\n",
" <td>28614.7927963</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>239000</td>\n",
" <td>182359.969776</td>\n",
" <td>56640.0302238</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>81010</td>\n",
" <td>116163.896629</td>\n",
" <td>-35153.8966288</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>117910</td>\n",
" <td>122633.009171</td>\n",
" <td>-4723.00917096</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>141100</td>\n",
" <td>148970.130184</td>\n",
" <td>-7870.13018375</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>148011</td>\n",
" <td>136271.021402</td>\n",
" <td>11739.9785975</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>66000</td>\n",
" <td>122744.326539</td>\n",
" <td>-56744.3265394</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 50000, 101423.246911724, -51423.2469117239),\n",
" (2, 85000, 123638.649032506, -38638.6490325065),\n",
" (3, 22500, 106084.260933004, -83584.2609330042),\n",
" (4, 90000, 119119.483641292, -29119.4836412917),\n",
" (5, 133000, 128188.345685371, 4811.6543146288),\n",
" (6, 90500, 108186.59434328, -17686.5943432805),\n",
" (7, 260000, 157119.812361022, 102880.187638978),\n",
" (8, 142500, 113936.466203536, 28563.5337964642),\n",
" (9, 160000, 131799.138887964, 28200.8611120356),\n",
" (10, 240000, 182915.476422748, 57084.5235772522),\n",
" (11, 87000, 116580.526613754, -29580.5266137536),\n",
" (12, 118600, 122841.324163419, -4241.32416341919),\n",
" (13, 140000, 148275.746875557, -8275.746875557),\n",
" (14, 148000, 134882.254786109, 13117.7452138913),\n",
" (15, 65000, 122049.943231186, -57049.9432311865),\n",
" (16, 91000, 118425.100333101, -27425.1003331007),\n",
" (17, 132300, 127493.96237718, 4806.03762281981),\n",
" (18, 91100, 106797.827726898, -15697.8277268985),\n",
" (19, 260011, 156425.429052831, 103585.570947169),\n",
" (20, 141800, 114630.849511727, 27169.1504882732),\n",
" (21, 160900, 132285.207203698, 28614.7927963019),\n",
" (22, 239000, 182359.969776195, 56640.030223805),\n",
" (23, 81010, 116163.896628839, -35153.896628839),\n",
" (24, 117910, 122633.009170962, -4723.00917096189),\n",
" (25, 141100, 148970.130183748, -7870.130183748),\n",
" (26, 148011, 136271.021402491, 11739.9785975093),\n",
" (27, 66000, 122744.326539377, -56744.3265393775)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Score each house with only the non-zero coefficients of houses_en2 and list the residuals.\n",
"SELECT\n",
"    scored.id,\n",
"    scored.price,\n",
"    scored.predict,\n",
"    scored.price - scored.predict AS residual\n",
"FROM (\n",
"    SELECT\n",
"        h.id,\n",
"        h.price,\n",
"        madlib.elastic_net_gaussian_predict(\n",
"            en2.coef_nonzero,                -- non-zero coefficients only\n",
"            en2.intercept,                   -- model intercept\n",
"            ARRAY[h.tax, h.size]             -- features matching the non-zero coefficients\n",
"        ) AS predict\n",
"    FROM houses AS h\n",
"    CROSS JOIN houses_en2 AS en2             -- pair every house with the model table\n",
") AS scored\n",
"ORDER BY scored.id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Cross validation\n",
"Reuse the houses table above. Here we use 3-fold cross validation with 3 automatically generated lambda values and 3 specified alpha values. (This can take some time to run since elastic net is effectively called 27 times for these combinations, i.e., 3 folds x 3 lambda values x 3 alpha values, and then a 28th time for the whole dataset.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Drop any earlier model, summary and cross validation tables so this cell can be re-run.\n",
"DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n",
"-- Train with 3-fold cross validation over 3 lambda values and the 3 alpha values listed\n",
"-- in the optimizer parameters. Per-combination scores are written to houses_en3_cv.\n",
"SELECT madlib.elastic_net_train( 'houses', -- Source table\n",
" 'houses_en3', -- Result table\n",
" 'price', -- Dependent variable\n",
" 'array[tax, bath, size]', -- Independent variables (array expression)\n",
" 'gaussian', -- Regression family\n",
" 0.5, -- Alpha value (presumably superseded by the alpha list below, TODO confirm)\n",
" 0.1, -- Lambda value (presumably superseded by n_lambdas below, TODO confirm)\n",
" TRUE, -- Standardize\n",
" NULL, -- Grouping column(s), none used here\n",
" 'fista', -- Optimizer\n",
" $$ n_folds = 3, -- Optimizer parameters\n",
" validation_result=houses_en3_cv,\n",
" n_lambdas = 3, \n",
" alpha = {0, 0.1, 1}\n",
" $$, \n",
" NULL, -- Excluded columns\n",
" 10000, -- Maximum iterations\n",
" 1e-6 -- Tolerance value\n",
" );\n",
"-- Show the result table written by the training call.\n",
"SELECT * FROM houses_en3;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Details of the cross validation:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>alpha</th>\n",
" <th>lambda_value</th>\n",
" <th>mean_neg_loss</th>\n",
" <th>std_neg_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.0</td>\n",
" <td>0.1</td>\n",
" <td>-36094.4685768</td>\n",
" <td>10524.4473253</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>0.1</td>\n",
" <td>-36136.2448004</td>\n",
" <td>10682.4136993</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>100.0</td>\n",
" <td>-37007.9496501</td>\n",
" <td>12679.3781975</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.1</td>\n",
" <td>-37018.1019927</td>\n",
" <td>12716.7438015</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>100.0</td>\n",
" <td>-59275.6940173</td>\n",
" <td>9764.50064237</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.0</td>\n",
" <td>100.0</td>\n",
" <td>-59380.252681</td>\n",
" <td>9763.26373034</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>100000.0</td>\n",
" <td>-60353.0220769</td>\n",
" <td>9748.10305107</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>100000.0</td>\n",
" <td>-143513752113000000000000000000000000000000000000000000</td>\n",
" <td>157073834312000000000000000000000000000000000000000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.0</td>\n",
" <td>100000.0</td>\n",
" <td>-11248884473800000000000000000000000000000000000000000000</td>\n",
" <td>9490568229990000000000000000000000000000000000000000000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('0.0'), Decimal('0.1'), Decimal('-36094.4685768'), Decimal('10524.4473253')),\n",
" (Decimal('0.1'), Decimal('0.1'), Decimal('-36136.2448004'), Decimal('10682.4136993')),\n",
" (Decimal('1.0'), Decimal('100.0'), Decimal('-37007.9496501'), Decimal('12679.3781975')),\n",
" (Decimal('1.0'), Decimal('0.1'), Decimal('-37018.1019927'), Decimal('12716.7438015')),\n",
" (Decimal('0.1'), Decimal('100.0'), Decimal('-59275.6940173'), Decimal('9764.50064237')),\n",
" (Decimal('0.0'), Decimal('100.0'), Decimal('-59380.252681'), Decimal('9763.26373034')),\n",
" (Decimal('1.0'), Decimal('100000.0'), Decimal('-60353.0220769'), Decimal('9748.10305107')),\n",
" (Decimal('0.1'), Decimal('100000.0'), Decimal('-143513752113000000000000000000000000000000000000000000'), Decimal('157073834312000000000000000000000000000000000000000000')),\n",
" (Decimal('0.0'), Decimal('100000.0'), Decimal('-11248884473800000000000000000000000000000000000000000000'), Decimal('9490568229990000000000000000000000000000000000000000000'))]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Cross validation scores per parameter combination, highest mean_neg_loss first.\n",
"SELECT *\n",
"FROM houses_en3_cv\n",
"ORDER BY mean_neg_loss DESC;"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>method</th>\n",
" <th>source_table</th>\n",
" <th>out_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>family</th>\n",
" <th>alpha</th>\n",
" <th>lambda_value</th>\n",
" <th>grouping_col</th>\n",
" <th>num_all_groups</th>\n",
" <th>num_failed_groups</th>\n",
" </tr>\n",
" <tr>\n",
" <td>elastic_net</td>\n",
" <td>houses</td>\n",
" <td>houses_en3</td>\n",
" <td>price</td>\n",
" <td>array[tax, bath, size]</td>\n",
" <td>gaussian</td>\n",
" <td>0.0</td>\n",
" <td>0.1</td>\n",
" <td>NULL</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'elastic_net', u'houses', u'houses_en3', u'price', u'array[tax, bath, size]', u'gaussian', 0.0, 0.1, u'NULL', 1, 0)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Summary table written alongside the houses_en3 result table.\n",
"SELECT *\n",
"FROM houses_en3_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
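"## 6a. Cross validation with a single alpha value\n",
"Here we use 3-fold cross validation with 3 automatically generated lambda values and a single alpha value (the 0.5 passed as the alpha argument), so elastic net is effectively called 9 times for these combinations. Note that this reuses the houses_en3 output table names, so the results from section 6 are replaced."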
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Drop any earlier model, summary and cross validation tables so this cell can be re-run.\n",
"DROP TABLE IF EXISTS houses_en3, houses_en3_summary, houses_en3_cv;\n",
"-- Train with 3-fold cross validation over 3 lambda values.\n",
"-- Per-combination scores are written to houses_en3_cv.\n",
"SELECT madlib.elastic_net_train( 'houses', -- Source table\n",
" 'houses_en3', -- Result table\n",
" 'price', -- Dependent variable\n",
" 'array[tax, bath, size]', -- Independent variables (array expression)\n",
" 'gaussian', -- Regression family\n",
" 0.5, -- Alpha value (no alpha list is given in the optimizer parameters)\n",
" 0.1, -- Lambda value (presumably superseded by n_lambdas below, TODO confirm)\n",
" TRUE, -- Standardize\n",
" NULL, -- Grouping column(s), none used here\n",
" 'fista', -- Optimizer\n",
" $$ n_folds = 3, -- Optimizer parameters\n",
" validation_result=houses_en3_cv,\n",
" n_lambdas = 3\n",
" $$, \n",
" NULL, -- Excluded columns\n",
" 10000, -- Maximum iterations\n",
" 1e-6 -- Tolerance value\n",
" );\n",
"-- Show the result table written by the training call.\n",
"SELECT * FROM houses_en3;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Details of the cross validation:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>lambda_value</th>\n",
" <th>mean_neg_loss</th>\n",
" <th>std_neg_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>100000.0</td>\n",
" <td>-255543791799000000000000000000000000000000000</td>\n",
" <td>442158712729000000000000000000000000000000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100.0</td>\n",
" <td>-59332.2198813</td>\n",
" <td>8220.8755071</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>-51938.9613421</td>\n",
" <td>28946.523247</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('100000.0'), Decimal('-255543791799000000000000000000000000000000000'), Decimal('442158712729000000000000000000000000000000000')),\n",
" (Decimal('100.0'), Decimal('-59332.2198813'), Decimal('8220.8755071')),\n",
" (Decimal('0.1'), Decimal('-51938.9613421'), Decimal('28946.523247'))]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Cross validation scores per lambda value, highest mean_neg_loss first.\n",
"SELECT *\n",
"FROM houses_en3_cv\n",
"ORDER BY mean_neg_loss DESC;"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%%sql\n",
"-- Summary table written alongside the houses_en3 result table.\n",
"SELECT *\n",
"FROM houses_en3_summary;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}