{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Support Vector Machines\n",
"Support Vector Machines (SVMs) are models for regression and classification tasks. SVM models have two particularly desirable features: robustness in the presence of noisy data and applicability to a variety of data configurations. At its core, a linear SVM model is a hyperplane separating two distinct classes of data (in the case of classification problems), in such a way that the distance between the hyperplane and the nearest training data point (called the margin) is maximized. Vectors that lie on this margin are called support vectors. With the support vectors fixed, perturbations of vectors beyond the margin will not affect the model; this contributes to the model’s robustness. By substituting a kernel function for the usual inner product, one can approximate a large variety of decision boundaries in addition to linear hyperplanes."
]
},
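{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of the idea above, the soft-margin linear SVM for classification is commonly written as the optimization problem below. This is the standard textbook form and is not necessarily the exact objective MADlib minimizes. The symbols $w$ (coefficient vector), $x_i$ (feature vector of row $i$), $y_i \\in \\{-1, +1\\}$ (class label), $n$ (number of rows) and $\\lambda$ (regularization strength) are notation assumed here for illustration.\n",
"\n",
"$$\\min_{w} \\; \\lambda \\lVert w \\rVert^2 + \\frac{1}{n} \\sum_{i=1}^{n} \\max\\left(0,\\; 1 - y_i \\, w^\\top x_i\\right)$$\n",
"\n",
"Rows with $y_i \\, w^\\top x_i > 1$ contribute nothing to the sum, which is why small perturbations of points beyond the margin leave the fitted model unchanged."
]
},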
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The sql extension is already loaded. To reload it, use:\n",
" %reload_ext sql\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.4.0 on GCP (demo machine)\n",
"%sql postgresql://gpadmin@35.184.253.255:5432/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib\n",
"\n",
"# Greenplum Database 4.3.10.0\n",
"#%sql postgresql://gpdbchina@10.194.10.68:61000/madlib"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.15-dev, git revision: rc/1.14-rc1-25-gda13eb7, cmake configuration time: Tue Jul 10 21:37:52 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.15-dev, git revision: rc/1.14-rc1-25-gda13eb7, cmake configuration time: Tue Jul 10 21:37:52 UTC 2018, build type: release, build system: Linux-2.6.32-696.20.1.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Classification\n",
"# 1. Create input data set"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"15 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS houses;\n",
"\n",
"CREATE TABLE houses (id INT, tax INT, bedroom INT, bath FLOAT, price INT,\n",
" size INT, lot INT);\n",
"\n",
"INSERT INTO houses VALUES \n",
" (1 , 590 , 2 , 1 , 50000 , 770 , 22100),\n",
" (2 , 1050 , 3 , 2 , 85000 , 1410 , 12000),\n",
" (3 , 20 , 3 , 1 , 22500 , 1060 , 3500),\n",
" (4 , 870 , 2 , 2 , 90000 , 1300 , 17500),\n",
" (5 , 1320 , 3 , 2 , 133000 , 1500 , 30000),\n",
" (6 , 1350 , 2 , 1 , 90500 , 820 , 25700),\n",
" (7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000),\n",
" (8 , 680 , 2 , 1 , 142500 , 1170 , 22000),\n",
" (9 , 1840 , 3 , 2 , 160000 , 1500 , 19000),\n",
" (10 , 3680 , 4 , 2 , 240000 , 2790 , 20000),\n",
" (11 , 1660 , 3 , 1 , 87000 , 1030 , 17500),\n",
" (12 , 1620 , 3 , 2 , 118600 , 1250 , 20000),\n",
" (13 , 3100 , 3 , 2 , 140000 , 1760 , 38000),\n",
" (14 , 2070 , 2 , 3 , 148000 , 1550 , 14000),\n",
" (15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000);\n",
" \n",
"SELECT * FROM houses ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# 2. Train linear classification model\n",
"The dependent variable is the boolean expression price < $100,000."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.124749754442359, -0.002823869432027, 0.0751780666986316, 0.00163774992345709]</td>\n",
" <td>0.647742474881</td>\n",
" <td>4412.03185101</td>\n",
" <td>100</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.124749754442359, -0.002823869432027, 0.0751780666986316, 0.00163774992345709], 0.647742474880954, 4412.03185100955, 100, 15L, 0L, [False, True])]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm, houses_svm_summary;\n",
"\n",
"SELECT madlib.svm_classification('houses',\n",
" 'houses_svm',\n",
" 'price < 100000',\n",
" 'ARRAY[1, tax, bath, size]'\n",
" );\n",
"SELECT * FROM houses_svm;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using linear model\n",
"We want to predict whether the house price is less than $100,000. We use the training data set for prediction as well, which is not how a model would normally be evaluated but serves to show the syntax. The predicted results are in the \"prediction\" column and the actual data is in the \"actual\" column."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" <th>actual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>False</td>\n",
" <td>-0.205087702693</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>False</td>\n",
" <td>-0.380729623714</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>True</td>\n",
" <td>1.87946535136</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>False</td>\n",
" <td>-0.0525856175296</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>False</td>\n",
" <td>-0.99577687725</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>False</td>\n",
" <td>-2.26934097486</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>False</td>\n",
" <td>-4.0774934572</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>True</td>\n",
" <td>0.195864017807</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>False</td>\n",
" <td>-2.4641889819</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-5.54741133557</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>False</td>\n",
" <td>-2.80081301486</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-2.25237518772</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>False</td>\n",
" <td>-5.59644948616</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>False</td>\n",
" <td>-2.9566133884</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>0.776739112686</td>\n",
" <td>True</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100, False, -0.205087702692976, True),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000, False, -0.380729623714223, True),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.87946535136497, True),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500, False, -0.0525856175296444, True),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -0.995776877250374, False),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700, False, -2.26934097486064, True),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -4.07749345720278, False),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000, True, 0.195864017807432, False),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -2.46418898190441, False),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -5.54741133557444, False),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500, False, -2.80081301486302, True),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -2.25237518772275, False),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -5.59644948615959, False),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -2.95661338839914, False),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000, True, 0.776739112685544, True)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_pred;\n",
"\n",
"SELECT madlib.svm_predict('houses_svm', \n",
" 'houses', \n",
" 'id', \n",
" 'houses_pred');\n",
"\n",
"SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred USING (id) ORDER BY id;"
]
},
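{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the linear model, the \"decision_function\" value appears to be the dot product of the learned coefficients with the feature array ARRAY[1, tax, bath, size]. For house 1 in the table above this gives roughly 0.1247 - 0.002824 * 590 + 0.07518 * 1 + 0.001638 * 770, about -0.205, which agrees with the reported value. The query below is a minimal sketch of that check using only the tables created earlier; the column alias manual_decision is a name chosen here for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT h.id,\n",
"       -- coef follows the order of ARRAY[1, tax, bath, size]; SQL arrays are 1-based\n",
"       m.coef[1] + m.coef[2] * h.tax + m.coef[3] * h.bath + m.coef[4] * h.size AS manual_decision,\n",
"       p.decision_function,\n",
"       p.prediction\n",
"FROM houses h\n",
"JOIN houses_pred p USING (id)\n",
"CROSS JOIN houses_svm m  -- the model table holds a single row\n",
"ORDER BY h.id;"
]
},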
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count the misclassifications:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6L,)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM houses_pred JOIN houses USING (id) \n",
"WHERE houses_pred.prediction != (houses.price < 100000);"
]
},
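{
"cell_type": "markdown",
"metadata": {},
"source": [
"A single count hides which kind of error the model makes. Reading the prediction table above, five of the six errors are houses under 100,000 that were predicted False, and one is a house over 100,000 that was predicted True. The query below is a small sketch that produces the same breakdown as a confusion matrix; the alias n is a name chosen here for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT (h.price < 100000) AS actual,\n",
"       p.prediction,\n",
"       COUNT(*) AS n  -- number of houses in each (actual, prediction) pair\n",
"FROM houses h\n",
"JOIN houses_pred p USING (id)\n",
"GROUP BY 1, 2\n",
"ORDER BY 1, 2;"
]
},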
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Train using Gaussian kernel\n",
"Next, generate a nonlinear model using a Gaussian kernel. This time we specify the initial step size and the maximum number of iterations to run. As part of the kernel parameters, we choose 10 (n_components) as the dimension of the space in which the SVM is trained. A larger number yields a more powerful model but runs the risk of overfitting. As a result, the model is a 10-dimensional vector, instead of a 4-dimensional one as in the linear case."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[-1.67275666209207, 1.5191640881642, -0.503066422926726, 1.33250956564454, 2.23009854231314, -0.0602475029497933, 1.97466397155921, 2.3668779833279, 0.577739846910355, 2.81255996089823]</td>\n",
" <td>0.0571869097341</td>\n",
" <td>1.18281830047</td>\n",
" <td>177</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([-1.67275666209207, 1.5191640881642, -0.503066422926726, 1.33250956564454, 2.23009854231314, -0.0602475029497933, 1.97466397155921, 2.3668779833279, 0.577739846910355, 2.81255996089823], 0.0571869097340992, 1.18281830047046, 177, 15L, 0L, [False, True])]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n",
"\n",
"SELECT madlib.svm_classification( 'houses',\n",
" 'houses_svm_gaussian',\n",
" 'price < 100000',\n",
" 'ARRAY[1, tax, bath, size]',\n",
" 'gaussian',\n",
" 'n_components=10',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200'\n",
" );\n",
"\n",
"SELECT * FROM houses_svm_gaussian;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Predict using Gaussian model\n",
"The predicted results are in the \"prediction\" column and the actual data is in the \"actual\" column."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" <th>actual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>True</td>\n",
" <td>1.64923454025</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>1.34505433447</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>True</td>\n",
" <td>1.00000000092</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>True</td>\n",
" <td>1.00000000712</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>False</td>\n",
" <td>-1.00000001729</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>True</td>\n",
" <td>1.11113745879</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>False</td>\n",
" <td>-0.29148279088</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>False</td>\n",
" <td>-1.00000000609</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>False</td>\n",
" <td>-1.23665846847</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-1.0938201061</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>True</td>\n",
" <td>1.62636283239</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-1.60116812307</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>False</td>\n",
" <td>-1.09173031656</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>False</td>\n",
" <td>-3.16301875478</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>1.00000000486</td>\n",
" <td>True</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100, True, 1.64923454025379, True),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000, True, 1.34505433446611, True),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.0000000009249, True),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500, True, 1.00000000711647, True),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -1.00000001728685, False),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700, True, 1.11113745878827, True),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -0.291482790879796, False),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000, False, -1.00000000609094, False),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -1.23665846846941, False),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -1.09382010610257, False),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500, True, 1.62636283239171, True),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -1.6011681230749, False),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -1.09173031656082, False),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -3.16301875478316, False),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000, True, 1.00000000486389, True)]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_pred_gaussian;\n",
"\n",
"SELECT madlib.svm_predict('houses_svm_gaussian', \n",
" 'houses', \n",
" 'id', \n",
" 'houses_pred_gaussian');\n",
"\n",
"SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred_gaussian USING (id) ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count the misclassifications. Note that the Gaussian model is more accurate than the linear model on this small data set, which is also the training set:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0L,)]"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM houses_pred_gaussian JOIN houses USING (id) \n",
"WHERE houses_pred_gaussian.prediction != (houses.price < 100000);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Balancing data sets\n",
"When the classes are of unequal size, set class_weight=balanced in the last argument when training the model:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.891926151039837, 0.169282494673541, -2.26539133689874, 0.526518499596676, -0.900664505989526, 0.508112011288015, -0.355474591147659, 1.23127975981665, 1.53694964239487, 1.46496058633682]</td>\n",
" <td>0.569002744458</td>\n",
" <td>0.989597662459</td>\n",
" <td>183</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.891926151039837, 0.169282494673541, -2.26539133689874, 0.526518499596676, -0.900664505989526, 0.508112011288015, -0.355474591147659, 1.23127975981665, 1.53694964239487, 1.46496058633682], 0.56900274445785, 0.989597662458527, 183, 15L, 0L, [False, True])]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n",
"\n",
"SELECT madlib.svm_classification( 'houses',\n",
" 'houses_svm_gaussian',\n",
" 'price < 150000',\n",
" 'ARRAY[1, tax, bath, size]',\n",
" 'gaussian',\n",
" 'n_components=10',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200, class_weight=balanced'\n",
" );\n",
"\n",
"SELECT * FROM houses_svm_gaussian;"
]
},
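{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why balancing is relevant for the 150,000 threshold, it helps to count the rows in each class. In the houses table only 3 of the 15 prices are 150,000 or more, so the classes split 12 to 3. The query below is a minimal sketch of that check; the aliases label and n are names chosen here for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT (price < 150000) AS label,\n",
"       COUNT(*) AS n  -- rows per class for the dependent variable used above\n",
"FROM houses\n",
"GROUP BY 1\n",
"ORDER BY 1;"
]
},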
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Regression\n",
"# 1. Create input data set\n",
"For regression we use part of the well-known abalone data set https://archive.ics.uci.edu/ml/datasets/abalone :"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"20 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone;\n",
"\n",
"CREATE TABLE abalone (id INT, sex TEXT, length FLOAT, diameter FLOAT, height FLOAT, rings INT);\n",
"\n",
"INSERT INTO abalone VALUES\n",
"(1,'M',0.455,0.365,0.095,15),\n",
"(2,'M',0.35,0.265,0.09,7),\n",
"(3,'F',0.53,0.42,0.135,9),\n",
"(4,'M',0.44,0.365,0.125,10),\n",
"(5,'I',0.33,0.255,0.08,7),\n",
"(6,'I',0.425,0.3,0.095,8),\n",
"(7,'F',0.53,0.415,0.15,20),\n",
"(8,'F',0.545,0.425,0.125,16),\n",
"(9,'M',0.475,0.37,0.125,9),\n",
"(10,'F',0.55,0.44,0.15,19),\n",
"(11,'F',0.525,0.38,0.14,14),\n",
"(12,'M',0.43,0.35,0.11,10),\n",
"(13,'M',0.49,0.38,0.135,11),\n",
"(14,'F',0.535,0.405,0.145,10),\n",
"(15,'F',0.47,0.355,0.1,10),\n",
"(16,'M',0.5,0.4,0.13,12),\n",
"(17,'I',0.355,0.28,0.085,7),\n",
"(18,'F',0.44,0.34,0.1,10),\n",
"(19,'M',0.365,0.295,0.08,7),\n",
"(20,'M',0.45,0.32,0.1,9);\n",
"\n",
"SELECT * FROM abalone ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Train linear regression model"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[1.998949892503, 0.918517478913099, 0.712125856084095, 0.229379472956877]</td>\n",
" <td>8.29033295818</td>\n",
" <td>23.2251777858</td>\n",
" <td>100</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([1.998949892503, 0.918517478913099, 0.712125856084095, 0.229379472956877], 8.29033295818392, 23.225177785827, 100, 20L, 0L, [None])]"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_regression, abalone_svm_regression_summary;\n",
"\n",
"SELECT madlib.svm_regression('abalone',\n",
" 'abalone_svm_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_regression;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using linear model"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" <td>2.69859233281</td>\n",
" <td>2.69859233281</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" <td>2.52978851455</td>\n",
" <td>2.52978851455</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" <td>2.81582324473</td>\n",
" <td>2.81582324473</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" <td>2.69169595482</td>\n",
" <td>2.69169595482</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>2.50200311168</td>\n",
" <td>2.50200311168</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" <td>2.6247486278</td>\n",
" <td>2.6247486278</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" <td>2.81570330755</td>\n",
" <td>2.81570330755</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" <td>2.83086784147</td>\n",
" <td>2.83086784147</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" <td>2.72740469586</td>\n",
" <td>2.72740469586</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" <td>2.85187680353</td>\n",
" <td>2.85187680353</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" <td>2.78389252046</td>\n",
" <td>2.78389252046</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" <td>2.66838820009</td>\n",
" <td>2.66838820009</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" <td>2.75059751133</td>\n",
" <td>2.75059751133</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" <td>2.81202773901</td>\n",
" <td>2.81202773901</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>2.7063957338</td>\n",
" <td>2.7063957338</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" <td>2.77287830588</td>\n",
" <td>2.77287830588</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" <td>2.54391609242</td>\n",
" <td>2.54391609242</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>2.66815832159</td>\n",
" <td>2.66815832159</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>2.56263625769</td>\n",
" <td>2.56263625769</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" <td>2.66310097926</td>\n",
" <td>2.66310097926</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15, 2.69859233281006, 2.69859233281006),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7, 2.52978851455099, 2.52978851455099),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9, 2.81582324473145, 2.81582324473145),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10, 2.69169595481507, 2.69169595481507),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7, 2.50200311168232, 2.50200311168232),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8, 2.6247486277972, 2.6247486277972),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20, 2.81570330754538, 2.81570330754538),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16, 2.83086784146599, 2.83086784146599),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9, 2.72740469585745, 2.72740469585745),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19, 2.85187680352574, 2.85187680352574),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14, 2.7838925204583, 2.7838925204583),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10, 2.66838820009033, 2.66838820009033),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11, 2.75059751133156, 2.75059751133156),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10, 2.81202773901432, 2.81202773901432),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10, 2.7063957337977, 2.7063957337977),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12, 2.77287830587759, 2.77287830587759),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7, 2.54391609242204, 2.54391609242204),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10, 2.66815832158905, 2.66815832158905),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7, 2.56263625768764, 2.56263625768764),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9, 2.6631009792565, 2.6631009792565)]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_regr;\n",
"\n",
"SELECT madlib.svm_predict('abalone_svm_regression',\n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_regr');\n",
"\n",
"SELECT * FROM abalone JOIN abalone_regr USING (id) ORDER BY id;"
]
},
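{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the linear kernel, the `prediction` values above look like the dot product of `coef` with the feature vector `ARRAY[1, length, diameter, height]`. For row 1 this gives roughly 2.6986, which agrees with the table. A minimal sketch that recomputes it in plain SQL, assuming the model table holds a single row; the column name `manual_prediction` is made up for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Recompute the linear prediction by hand from the stored coefficients.\n",
"-- PostgreSQL arrays are 1-indexed, so coef[1] multiplies the constant term.\n",
"SELECT a.id,\n",
"       m.coef[1] + m.coef[2]*a.length + m.coef[3]*a.diameter + m.coef[4]*a.height AS manual_prediction,\n",
"       p.prediction\n",
"FROM abalone a\n",
"JOIN abalone_regr p USING (id)\n",
"CROSS JOIN abalone_svm_regression m\n",
"ORDER BY a.id;"
]
},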
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the RMS error of the linear model's predictions:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>9.08842725553</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(9.08842725552861,)]"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_regr USING (id);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Train regression model using Gaussian kernel"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707]</td>\n",
" <td>2.66850539542</td>\n",
" <td>0.974400795364</td>\n",
" <td>163</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707], 2.66850539541894, 0.97440079536379, 163, 20L, 0L, [None])]"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, abalone_svm_gaussian_regression_random;\n",
"\n",
"SELECT madlib.svm_regression( 'abalone',\n",
" 'abalone_svm_gaussian_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]',\n",
" 'gaussian',\n",
" 'n_components=10',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_gaussian_regression;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Predict using Gaussian model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" <td>9.9302009808</td>\n",
" <td>9.9302009808</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" <td>9.87712610207</td>\n",
" <td>9.87712610207</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" <td>10.0459812729</td>\n",
" <td>10.0459812729</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" <td>10.018415777</td>\n",
" <td>10.018415777</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>9.81382643977</td>\n",
" <td>9.81382643977</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" <td>9.973725783</td>\n",
" <td>9.973725783</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" <td>10.1032556038</td>\n",
" <td>10.1032556038</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" <td>10.0140320794</td>\n",
" <td>10.0140320794</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" <td>10.0478657373</td>\n",
" <td>10.0478657373</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" <td>10.0698224494</td>\n",
" <td>10.0698224494</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" <td>10.1259635318</td>\n",
" <td>10.1259635318</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" <td>9.97481060063</td>\n",
" <td>9.97481060063</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" <td>10.0805427887</td>\n",
" <td>10.0805427887</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" <td>10.107947317</td>\n",
" <td>10.107947317</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>9.97781238334</td>\n",
" <td>9.97781238334</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" <td>10.0409088715</td>\n",
" <td>10.0409088715</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" <td>9.8548093316</td>\n",
" <td>9.8548093316</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>9.96407219215</td>\n",
" <td>9.96407219215</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>9.83873423654</td>\n",
" <td>9.83873423654</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" <td>10.0003544239</td>\n",
" <td>10.0003544239</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15, 9.93020098079582, 9.93020098079582),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7, 9.87712610207203, 9.87712610207203),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9, 10.045981272917, 10.045981272917),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10, 10.0184157770077, 10.0184157770077),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7, 9.81382643976989, 9.81382643976989),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8, 9.97372578299521, 9.97372578299521),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20, 10.1032556037805, 10.1032556037805),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16, 10.0140320794144, 10.0140320794144),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9, 10.0478657373155, 10.0478657373155),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19, 10.0698224493735, 10.0698224493735),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14, 10.1259635317559, 10.1259635317559),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10, 9.97481060062509, 9.97481060062509),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11, 10.0805427887436, 10.0805427887436),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10, 10.107947317027, 10.107947317027),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10, 9.97781238333585, 9.97781238333585),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12, 10.0409088715201, 10.0409088715201),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7, 9.85480933160473, 9.85480933160473),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10, 9.96407219215287, 9.96407219215287),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7, 9.83873423654298, 9.83873423654298),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9, 10.0003544238551, 10.0003544238551)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_gaussian_regr;\n",
"\n",
"SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_gaussian_regr');\n",
"\n",
"SELECT * FROM abalone JOIN abalone_gaussian_regr USING (id) ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the RMS error. Note that the Gaussian kernel gives a lower RMS error than the linear model on this small data set:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3.84176368344</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3.84176368343915,)]"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_gaussian_regr USING (id);"
]
},
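{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see both errors side by side, the two RMS queries can be combined. This is only a sketch: it assumes the prediction tables `abalone_regr` and `abalone_gaussian_regr` from the earlier steps still exist, and the `model` label column is made up for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- One row per model, each with the RMS error of its predictions against rings.\n",
"SELECT CAST('linear' AS TEXT) AS model,\n",
"       SQRT(AVG((rings-prediction)*(rings-prediction))) AS rms_error\n",
"FROM abalone JOIN abalone_regr USING (id)\n",
"UNION ALL\n",
"SELECT CAST('gaussian' AS TEXT) AS model,\n",
"       SQRT(AVG((rings-prediction)*(rings-prediction))) AS rms_error\n",
"FROM abalone JOIN abalone_gaussian_regr USING (id);"
]
},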
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Cross validation\n",
"Let's run cross validation for different initial step sizes and lambda values:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707]</td>\n",
" <td>2.63941855054</td>\n",
" <td>1.07622244533</td>\n",
" <td>163</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([4.49016341280977, 2.19062972461334, -2.04673653356154, 1.11216153651262, 2.83478599238881, -4.23122821845785, 4.17684533744501, -5.36892552740644, 0.775782561685621, -3.62606941016707], 2.63941855054256, 1.07622244533275, 163, 20L, 0L, [None])]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, \n",
"abalone_svm_gaussian_regression_random, abalone_svm_gaussian_regression_cv;\n",
"\n",
"SELECT madlib.svm_regression( 'abalone',\n",
" 'abalone_svm_gaussian_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]',\n",
" 'gaussian',\n",
" 'n_components=10',\n",
" '',\n",
" 'init_stepsize=[0.01,1], n_folds=3, max_iter=200, lambda=[0.01, 0.1, 0.5], validation_result=abalone_svm_gaussian_regression_cv'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_gaussian_regression;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the summary table, which shows that the final model parameters are those that produced\n",
"the lowest error in the cross validation runs:"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>method</th>\n",
" <th>version_number</th>\n",
" <th>source_table</th>\n",
" <th>model_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>kernel_func</th>\n",
" <th>kernel_params</th>\n",
" <th>grouping_col</th>\n",
" <th>optim_params</th>\n",
" <th>reg_params</th>\n",
" <th>num_all_groups</th>\n",
" <th>num_failed_groups</th>\n",
" <th>total_rows_processed</th>\n",
" <th>total_rows_skipped</th>\n",
" </tr>\n",
" <tr>\n",
" <td>SVR</td>\n",
" <td>1.15-dev</td>\n",
" <td>abalone</td>\n",
" <td>abalone_svm_gaussian_regression</td>\n",
" <td>rings</td>\n",
" <td>ARRAY[1, length, diameter, height]</td>\n",
" <td>gaussian</td>\n",
" <td>gamma=0.25, n_components=10,random_state=1, fit_intercept=False, fit_in_memory=True</td>\n",
" <td>NULL</td>\n",
" <td> init_stepsize=1.0,<br> decay_factor=0.9,<br> max_iter=200,<br> tolerance=1e-10,<br> epsilon=0.01,<br> eps_table=,<br> class_weight=<br> </td>\n",
" <td>lambda=0.01, norm=l2, n_folds=3</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'SVR', u'1.15-dev', u'abalone', u'abalone_svm_gaussian_regression', u'rings', u'ARRAY[1, length, diameter, height]', u'gaussian', u'gamma=0.25, n_components=10,random_state=1, fit_intercept=False, fit_in_memory=True', u'NULL', u' init_stepsize=1.0,\\n decay_factor=0.9,\\n max_iter=200,\\n tolerance=1e-10,\\n epsilon=0.01,\\n eps_table=,\\n class_weight=\\n ', u'lambda=0.01, norm=l2, n_folds=3', 1, 0, 20L, 0L)]"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql SELECT * FROM abalone_svm_gaussian_regression_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the values for cross validation:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>init_stepsize</th>\n",
" <th>lambda</th>\n",
" <th>mean_score</th>\n",
" <th>std_dev_score</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.01</td>\n",
" <td>-4.06711568585</td>\n",
" <td>0.435966381366</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.1</td>\n",
" <td>-4.08068428345</td>\n",
" <td>0.44660797513</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.5</td>\n",
" <td>-4.52576046087</td>\n",
" <td>0.20597876382</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>-11.0231044189</td>\n",
" <td>0.739956548721</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.1</td>\n",
" <td>-11.0244799274</td>\n",
" <td>0.740029346709</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.5</td>\n",
" <td>-11.0305445077</td>\n",
" <td>0.740350338532</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('1.0'), Decimal('0.01'), Decimal('-4.06711568585'), Decimal('0.435966381366')),\n",
" (Decimal('1.0'), Decimal('0.1'), Decimal('-4.08068428345'), Decimal('0.44660797513')),\n",
" (Decimal('1.0'), Decimal('0.5'), Decimal('-4.52576046087'), Decimal('0.20597876382')),\n",
" (Decimal('0.01'), Decimal('0.01'), Decimal('-11.0231044189'), Decimal('0.739956548721')),\n",
" (Decimal('0.01'), Decimal('0.1'), Decimal('-11.0244799274'), Decimal('0.740029346709')),\n",
" (Decimal('0.01'), Decimal('0.5'), Decimal('-11.0305445077'), Decimal('0.740350338532'))]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_svm_gaussian_regression_cv;"
]
},
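{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scores above are negative, and the parameters reported in the summary table (`init_stepsize=1.0`, `lambda=0.01`) belong to the row with the largest `mean_score`. Assuming that a larger `mean_score` means a lower error, a sketch for pulling out the winning combination directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Assumption, larger mean_score is better, so sort descending and keep the top row.\n",
"SELECT *\n",
"FROM abalone_svm_gaussian_regression_cv\n",
"ORDER BY mean_score DESC\n",
"LIMIT 1;"
]
},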
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7. Predict using cross-validated Gaussian regression model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>svm_predict</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_gaussian_regr;\n",
"SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_gaussian_regr');"
]
},
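{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the RMS error. In the output shown here it is the same as in the previous run with the Gaussian kernel, because cross validation selected init_stepsize=1.0 and lambda=0.01 and the fitted coefficients are unchanged:"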
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3.84176368344</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3.84176368343915,)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_gaussian_regr USING (id);"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Novelty detection\n",
"# 1. Train a non-linear one-class SVM\n",
"Train a one-class SVM with a Gaussian kernel on the housing data set. Note that the one-class function takes no dependent variable parameter:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, -0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694]</td>\n",
" <td>0.944016313708</td>\n",
" <td>14.5271059047</td>\n",
" <td>100</td>\n",
" <td>16</td>\n",
" <td>-1</td>\n",
" <td>[-1.0, 1.0]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, -0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694], 0.944016313708205, 14.5271059047443, 100, 16L, -1L, [-1.0, 1.0])]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_one_class_gaussian, houses_one_class_gaussian_summary, houses_one_class_gaussian_random;\n",
"\n",
"SELECT madlib.svm_one_class('houses',\n",
" 'houses_one_class_gaussian',\n",
" 'ARRAY[1,tax,bedroom,bath,size,lot,price]',\n",
" 'gaussian',\n",
" 'gamma=0.5,n_components=55, random_state=3',\n",
" NULL,\n",
" 'max_iter=100, init_stepsize=10,lambda=10, tolerance=0'\n",
" );\n",
"\n",
"SELECT * FROM houses_one_class_gaussian;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Create test data\n",
"For novelty detection with the one-class model, let's create a test data set from the last 3 rows of the training set plus an outlier at the end (a copy of the third row with 10x the price):"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"4 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>650000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 3100, 3, 2.0, 140000, 1760, 38000),\n",
" (2, 2070, 2, 3.0, 148000, 1550, 14000),\n",
" (3, 650, 3, 1.5, 65000, 1450, 12000),\n",
" (4, 650, 3, 1.5, 650000, 1450, 12000)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_one_class_test;\n",
"\n",
"CREATE TABLE houses_one_class_test (id INT, tax INT, bedroom INT, bath FLOAT, price INT,\n",
" size INT, lot INT);\n",
"\n",
"INSERT INTO houses_one_class_test VALUES \n",
" (1 , 3100 , 3 , 2 , 140000 , 1760 , 38000),\n",
" (2 , 2070 , 2 , 3 , 148000 , 1550 , 14000),\n",
" (3 , 650 , 3 , 1.5 , 65000 , 1450 , 12000),\n",
" (4 , 650 , 3 , 1.5 , 650000 , 1450 , 12000);\n",
" \n",
"SELECT * FROM houses_one_class_test ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using Gaussian one-class novelty detection model\n",
"The result shows that the last row is predicted to be novel (prediction of -1 and a negative decision function):"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>1.0</td>\n",
" <td>0.0662278474212</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>1.0</td>\n",
" <td>0.092124936453</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>1.0</td>\n",
" <td>0.03415206006</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>650000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>-1.0</td>\n",
" <td>-0.0131918729845</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 3100, 3, 2.0, 140000, 1760, 38000, 1.0, 0.066227847421185),\n",
" (2, 2070, 2, 3.0, 148000, 1550, 14000, 1.0, 0.0921249364529948),\n",
" (3, 650, 3, 1.5, 65000, 1450, 12000, 1.0, 0.0341520600599523),\n",
" (4, 650, 3, 1.5, 650000, 1450, 12000, -1.0, -0.0131918729845241)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS houses_pred;\n",
"\n",
"SELECT madlib.svm_predict('houses_one_class_gaussian', \n",
" 'houses_one_class_test', \n",
" 'id', \n",
" 'houses_pred');\n",
"\n",
"SELECT * FROM houses_one_class_test JOIN houses_pred USING (id) ORDER BY id;"
]
}
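,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On a larger test set it may be easier to list only the rows flagged as novel. A minimal sketch, assuming novel rows carry a negative `prediction` as in the output above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Keep only the test rows that the one-class model flags as novel.\n",
"SELECT id, price, prediction, decision_function\n",
"FROM houses_one_class_test JOIN houses_pred USING (id)\n",
"WHERE prediction < 0\n",
"ORDER BY id;"
]
}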
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}