{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Support Vector Machines\n",
"Support Vector Machines (SVMs) are models for regression and classification tasks. SVM models have two particularly desirable features: robustness in the presence of noisy data and applicability to a variety of data configurations. At its core, a linear SVM model is a hyperplane separating two distinct classes of data (in the case of classification problems), in such a way that the distance between the hyperplane and the nearest training data point (called the margin) is maximized. Vectors that lie on this margin are called support vectors. With the support vectors fixed, perturbations of vectors beyond the margin will not affect the model; this contributes to the model’s robustness. By substituting a kernel function for the usual inner product, one can approximate a large variety of decision boundaries in addition to linear hyperplanes."
]
},
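{
"cell_type": "markdown",
"metadata": {},
"source": [
"The margin idea above can be sketched as an optimization problem. This is the standard soft-margin formulation, shown for illustration; it is not necessarily the exact objective that MADlib optimizes, and the symbols $\\lambda$ (regularization strength) and $n$ (number of training rows) are assumptions introduced here. For training points $(x_i, y_i)$ with labels $y_i \\in \\lbrace -1, +1 \\rbrace$, a linear SVM looks for coefficients $w$ that minimize the regularized hinge loss:\n",
"\n",
"$$\\min_{w} \\; \\frac{\\lambda}{2} \\lVert w \\rVert^2 + \\frac{1}{n} \\sum_{i=1}^{n} \\max\\bigl(0, \\; 1 - y_i \\, w^\\top x_i\\bigr)$$\n",
"\n",
"Points with $y_i \\, w^\\top x_i > 1$ lie beyond the margin and contribute zero loss, which is why perturbing them does not change the solution. The intercept can be folded into $w$ by adding a constant feature, as done with `ARRAY[1, ...]` in the examples below."
]
},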
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
" \"You should import from traitlets.config instead.\", ShimWarning)\n",
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
"#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
"\n",
"# Greenplum Database 5.x on GCP (PM demo machine) - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.17-dev, git revision: rel/v1.16-29-g9fa27e5, cmake configuration time: Mon Oct 7 17:04:14 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-29-g9fa27e5, cmake configuration time: Mon Oct 7 17:04:14 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Classification\n",
"# 1. Create input data set"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"15 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS houses;\n",
"\n",
"CREATE TABLE houses (id INT, tax INT, bedroom INT, bath FLOAT, price INT,\n",
" size INT, lot INT);\n",
"\n",
"INSERT INTO houses VALUES \n",
" (1 , 590 , 2 , 1 , 50000 , 770 , 22100),\n",
" (2 , 1050 , 3 , 2 , 85000 , 1410 , 12000),\n",
" (3 , 20 , 3 , 1 , 22500 , 1060 , 3500),\n",
" (4 , 870 , 2 , 2 , 90000 , 1300 , 17500),\n",
" (5 , 1320 , 3 , 2 , 133000 , 1500 , 30000),\n",
" (6 , 1350 , 2 , 1 , 90500 , 820 , 25700),\n",
" (7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000),\n",
" (8 , 680 , 2 , 1 , 142500 , 1170 , 22000),\n",
" (9 , 1840 , 3 , 2 , 160000 , 1500 , 19000),\n",
" (10 , 3680 , 4 , 2 , 240000 , 2790 , 20000),\n",
" (11 , 1660 , 3 , 1 , 87000 , 1030 , 17500),\n",
" (12 , 1620 , 3 , 2 , 118600 , 1250 , 20000),\n",
" (13 , 3100 , 3 , 2 , 140000 , 1760 , 38000),\n",
" (14 , 2070 , 2 , 3 , 148000 , 1550 , 14000),\n",
" (15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000);\n",
" \n",
"SELECT * FROM houses ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# 2. Train linear classification model\n",
"The dependent variable is the boolean expression price < $100,000."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.12229100715556, -0.00311209904999331, 0.0729255891679728, 0.00159299038324124]</td>\n",
" <td>0.724409814168</td>\n",
" <td>4412.03185362</td>\n",
" <td>100</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.12229100715556, -0.00311209904999331, 0.0729255891679728, 0.00159299038324124], 0.724409814168392, 4412.03185361608, 100, 15L, 0L, [False, True])]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm, houses_svm_summary;\n",
"\n",
"SELECT madlib.svm_classification('houses',\n",
" 'houses_svm',\n",
" 'price < 100000',\n",
" 'ARRAY[1, tax, bath, size]'\n",
" );\n",
"SELECT * FROM houses_svm;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using linear model\n",
"We want to predict whether the house price is less than $100,000. We use the training data set for prediction as well, which is not usual practice but serves to show the syntax. The predicted results are in the \"prediction\" column and the actual values are in the \"actual\" column."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" <th>actual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>False</td>\n",
" <td>-0.414319248077</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>False</td>\n",
" <td>-0.753445376631</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>True</td>\n",
" <td>1.82154442156</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>False</td>\n",
" <td>-0.368496489789</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>False</td>\n",
" <td>-1.45034298564</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>False</td>\n",
" <td>-2.69986500691</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>False</td>\n",
" <td>-4.9850818531</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>False</td>\n",
" <td>-0.0572120092797</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>False</td>\n",
" <td>-3.06863449163</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-6.73993914924</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>False</td>\n",
" <td>-3.33008773193</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-2.78222029645</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>False</td>\n",
" <td>-6.57570179498</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>False</td>\n",
" <td>-3.6318421648</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>0.518651064112</td>\n",
" <td>True</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100, False, -0.414319248076766, True),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000, False, -0.753445376631322, True),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.82154442155938, True),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500, False, -0.368496489789063, True),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -1.45034298563781, False),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700, False, -2.69986500690962, True),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -4.98508185310201, False),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000, False, -0.0572120092796671, False),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -3.06863449163433, False),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -6.73993914924082, False),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500, False, -3.33008773192689, True),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -2.78222029644611, False),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -6.57570179498318, False),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -3.63184216480275, False),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000, True, 0.518651064111666, True)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_pred;\n",
"\n",
"SELECT madlib.svm_predict('houses_svm', \n",
" 'houses', \n",
" 'id', \n",
" 'houses_pred');\n",
"\n",
"SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred USING (id) ORDER BY id;"
]
},
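{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the linear model, the decision function can be cross-checked by hand. The query below is a sketch that assumes `coef` is stored in the same order as the independent variable expression `ARRAY[1, tax, bath, size]` and that no feature scaling was applied during training; `manual_decision` is an illustrative column name. If those assumptions hold, the two columns should agree up to floating point rounding, and the sign of the value determines the predicted class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Dot product of the learned coefficients with [1, tax, bath, size]\n",
"SELECT h.id,\n",
"       m.coef[1] + m.coef[2] * h.tax + m.coef[3] * h.bath + m.coef[4] * h.size AS manual_decision,\n",
"       p.decision_function\n",
"FROM houses h\n",
"JOIN houses_pred p USING (id)\n",
"CROSS JOIN houses_svm m\n",
"ORDER BY h.id;"
]
},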
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count the misclassifications:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(5L,)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM houses_pred JOIN houses USING (id) \n",
"WHERE houses_pred.prediction != (houses.price < 100000);"
]
},
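{
"cell_type": "markdown",
"metadata": {},
"source": [
"A single count hides which class is being misclassified. Breaking the same join down by actual and predicted label shows where the errors fall. This is a small sketch using only standard SQL; the column aliases `actual`, `predicted` and `n` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Confusion matrix: one row per (actual, predicted) combination\n",
"SELECT (houses.price < 100000) AS actual,\n",
"       houses_pred.prediction AS predicted,\n",
"       COUNT(*) AS n\n",
"FROM houses_pred JOIN houses USING (id)\n",
"GROUP BY 1, 2\n",
"ORDER BY 1, 2;"
]
},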
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Train using Gaussian kernel\n",
"Next, generate a nonlinear model using a Gaussian kernel. This time we specify the initial step size and the maximum number of iterations to run. As part of the kernel parameters, we choose 100 (n_components) as the dimension of the space in which the SVM is trained. A larger number leads to a more powerful model but runs the risk of overfitting. As a result, the model is a 100-dimensional vector, instead of a 4-dimensional one as in the linear case."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.174449870282045, 0.221949803360671, -0.402333071003166, -0.262676154168673, 1.13130014096269, -0.477266861255365, 0.831261399413409, 0.603981268320905, 0.470637538053085, 0.504604196295102, 0.0285966246226131, -0.405297458067235, 0.738052957351954, 0.0144918647607145, 0.216536544966497, 0.320016273243399, 0.00232337917551938, -0.0843958273523198, -0.298021033404343, 0.879347233724568, 0.116421698629981, 0.376803883069554, 0.46563131778655, -0.203321710241967, -0.236572471896519, 0.383255178998929, -0.492255525590516, 0.983589649060747, 0.236964019247516, 0.416962283463462, -0.385800872437151, -0.345329823794884, 0.120732524424864, 0.710780837147856, 0.264794396724802, -0.489555512573859, 0.380902730877136, 0.0536009019790365, 0.73200275671809, 0.542826436375729, 0.168712544719772, -0.132541621427206, -0.214051414056401, 0.219073863854421, 0.730947258770631, 0.162460261370712, -0.144429756928593, -0.935475925932472, -0.0297216095002956, 0.193381904334028, 0.448848607315711, -1.18721596180884, 0.482417336135806, -0.52571602337991, -0.818120577090183, 0.260984129515484, -0.557576024742713, -0.83770700093059, -0.253588929795698, -0.463277449925523, -0.0810102775249212, -0.181539340380399, -0.21854707989751, -0.168502090717447, 0.457147867346449, 0.00313147461640386, 0.468228657916444, 0.295498231571813, -0.0862639643663198, -0.252594854475285, 0.732161837953837, 0.621826163021027, 0.821103686828688, -0.0698709773990468, 0.320661634243996, -0.518205856242166, 0.175088291193444, 0.464641905215399, 0.546148979761043, 0.459529900686438, 0.846112959072494, 0.438090206970192, -0.394317794019626, -0.421641334028624, 0.0886399404180068, -0.366570124483957, -0.298487191729675, 0.0355220264779079, -0.009380543657684, 0.510081333996271, -0.368854827597419, 0.216087166764326, -0.411797703281369, -0.163229347789594, 0.229469291437079, -0.164150670617032, 0.0745696677770123, -0.452725636607896, -0.861040652904235, 0.342036947159568]</td>\n",
" <td>0.00722931656529</td>\n",
" <td>0.0465703228418</td>\n",
" <td>177</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.174449870282045, 0.221949803360671, -0.402333071003166, -0.262676154168673, 1.13130014096269, -0.477266861255365, 0.831261399413409, 0.603981268320905, 0.470637538053085, 0.504604196295102, 0.0285966246226131, -0.405297458067235, 0.738052957351954, 0.0144918647607145, 0.216536544966497, 0.320016273243399, 0.00232337917551938, -0.0843958273523198, -0.298021033404343, 0.879347233724568, 0.116421698629981, 0.376803883069554, 0.46563131778655, -0.203321710241967, -0.236572471896519, 0.383255178998929, -0.492255525590516, 0.983589649060747, 0.236964019247516, 0.416962283463462, -0.385800872437151, -0.345329823794884, 0.120732524424864, 0.710780837147856, 0.264794396724802, -0.489555512573859, 0.380902730877136, 0.0536009019790365, 0.73200275671809, 0.542826436375729, 0.168712544719772, -0.132541621427206, -0.214051414056401, 0.219073863854421, 0.730947258770631, 0.162460261370712, -0.144429756928593, -0.935475925932472, -0.0297216095002956, 0.193381904334028, 0.448848607315711, -1.18721596180884, 0.482417336135806, -0.52571602337991, -0.818120577090183, 0.260984129515484, -0.557576024742713, -0.83770700093059, -0.253588929795698, -0.463277449925523, -0.0810102775249212, -0.181539340380399, -0.21854707989751, -0.168502090717447, 0.457147867346449, 0.00313147461640386, 0.468228657916444, 0.295498231571813, -0.0862639643663198, -0.252594854475285, 0.732161837953837, 0.621826163021027, 0.821103686828688, -0.0698709773990468, 0.320661634243996, -0.518205856242166, 0.175088291193444, 0.464641905215399, 0.546148979761043, 0.459529900686438, 0.846112959072494, 0.438090206970192, -0.394317794019626, -0.421641334028624, 0.0886399404180068, -0.366570124483957, -0.298487191729675, 0.0355220264779079, -0.009380543657684, 0.510081333996271, -0.368854827597419, 0.216087166764326, -0.411797703281369, -0.163229347789594, 0.229469291437079, -0.164150670617032, 0.0745696677770123, -0.452725636607896, -0.861040652904235, 0.342036947159568], 0.00722931656528619, 0.0465703228417611, 177, 15L, 0L, [False, True])]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n",
"\n",
"SELECT madlib.svm_classification( 'houses',\n",
" 'houses_svm_gaussian',\n",
" 'price < 100000',\n",
" 'ARRAY[1, tax, bath, size]',\n",
" 'gaussian',\n",
" 'n_components=100',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200'\n",
" );\n",
"\n",
"SELECT * FROM houses_svm_gaussian;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Predict using Gaussian model\n",
"The predicted results are in the \"prediction\" column and the actual data is in the \"actual\" column."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"15 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" <th>actual</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>50000</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>True</td>\n",
" <td>1.16840726459</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>85000</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>1.12629659629</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>22500</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>True</td>\n",
" <td>1.03236268569</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>90000</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>True</td>\n",
" <td>1.17721999735</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>133000</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>False</td>\n",
" <td>-1.38210027041</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>90500</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>True</td>\n",
" <td>1.16225381842</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>260000</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>False</td>\n",
" <td>-1.28234134504</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>142500</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>False</td>\n",
" <td>-1.14421198085</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>160000</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>False</td>\n",
" <td>-1.04766021862</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>240000</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-1.0549708301</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>87000</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>True</td>\n",
" <td>1.17827790368</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>118600</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>False</td>\n",
" <td>-1.74452909954</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>False</td>\n",
" <td>-1.11954973223</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>False</td>\n",
" <td>-1.0382466233</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>True</td>\n",
" <td>1.29000398657</td>\n",
" <td>True</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 50000, 770, 22100, True, 1.16840726459313, True),\n",
" (2, 1050, 3, 2.0, 85000, 1410, 12000, True, 1.12629659629046, True),\n",
" (3, 20, 3, 1.0, 22500, 1060, 3500, True, 1.03236268569294, True),\n",
" (4, 870, 2, 2.0, 90000, 1300, 17500, True, 1.1772199973482, True),\n",
" (5, 1320, 3, 2.0, 133000, 1500, 30000, False, -1.38210027041179, False),\n",
" (6, 1350, 2, 1.0, 90500, 820, 25700, True, 1.16225381842323, True),\n",
" (7, 2790, 3, 2.5, 260000, 2130, 25000, False, -1.28234134503531, False),\n",
" (8, 680, 2, 1.0, 142500, 1170, 22000, False, -1.14421198085148, False),\n",
" (9, 1840, 3, 2.0, 160000, 1500, 19000, False, -1.04766021861739, False),\n",
" (10, 3680, 4, 2.0, 240000, 2790, 20000, False, -1.05497083009726, False),\n",
" (11, 1660, 3, 1.0, 87000, 1030, 17500, True, 1.17827790367511, True),\n",
" (12, 1620, 3, 2.0, 118600, 1250, 20000, False, -1.74452909953604, False),\n",
" (13, 3100, 3, 2.0, 140000, 1760, 38000, False, -1.11954973222788, False),\n",
" (14, 2070, 2, 3.0, 148000, 1550, 14000, False, -1.03824662329829, False),\n",
" (15, 650, 3, 1.5, 65000, 1450, 12000, True, 1.29000398656605, True)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_pred_gaussian;\n",
"\n",
"SELECT madlib.svm_predict('houses_svm_gaussian', \n",
" 'houses', \n",
" 'id', \n",
" 'houses_pred_gaussian');\n",
"\n",
"SELECT *, price < 100000 AS actual FROM houses JOIN houses_pred_gaussian USING (id) ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count the misclassifications. Note that this produces a more accurate result on the training data than the linear case for this small data set:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0L,)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM houses_pred_gaussian JOIN houses USING (id) \n",
"WHERE houses_pred_gaussian.prediction != (houses.price < 100000);"
]
},
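{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two models can also be compared side by side in one query. The sketch below uses only standard SQL against the prediction tables created above; the aliases `linear_accuracy` and `gaussian_accuracy` are illustrative. Both figures are training-set accuracy, so they say little about how either model would generalize to unseen houses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"-- Fraction of rows where each model's prediction matches the actual label\n",
"SELECT AVG(CASE WHEN l.prediction = (h.price < 100000) THEN 1.0 ELSE 0.0 END) AS linear_accuracy,\n",
"       AVG(CASE WHEN g.prediction = (h.price < 100000) THEN 1.0 ELSE 0.0 END) AS gaussian_accuracy\n",
"FROM houses h\n",
"JOIN houses_pred l USING (id)\n",
"JOIN houses_pred_gaussian g USING (id);"
]
},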
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Balancing data sets\n",
"When the classes in the data set differ greatly in size, set 'class_weight=balanced' in the params argument when building the model:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.0304283785262617, -0.31252256022358, -0.786594823395088, -0.213462165447043, -0.205921127201806, -0.160495098525847, -0.0572558143143236, 0.183220396715039, 0.464792217394048, 0.481483244258389, -0.517643614876286, -0.978683507901074, 0.16817691490251, -0.646816511296802, 0.186375390577986, 0.680958110801916, -1.08232636410793, -0.945175946274317, -0.491574430145303, -0.0814842224959727, -0.0190008166655649, 0.0481772284935738, -0.163702000177582, -0.674004897487282, 0.113518490341767, -0.638187403343937, 0.526552063250668, -0.274100343388661, 0.354431317955514, -0.428444014517539, 0.0946683130713131, -0.239646558966188, -0.288975110225114, 0.277634287723891, 0.109083762491799, -0.590472152297871, 0.30239084357163, -0.644378259476824, 0.518616701508965, -0.0310448850251757, -0.0616074328876328, -0.815709238025655, 0.533952545147382, -0.27885652806791, -0.169816368423772, 0.501969605761016, -0.0453904283532783, -0.296542126733526, -0.6744641877468, -1.04295422716397, 0.0998805368884473, -0.0581387461992054, 0.226951693444486, 0.183293767588985, -0.506636260378821, -0.182587120340722, -0.632707861890101, 0.165980897095258, -0.918139789600548, -0.770637770944717, 0.945310986017951, -1.02669717302724, 0.608102258578773, 0.22613081465608, 0.141010992575908, -0.154732938082785, 0.673276713057085, -0.252979879432229, 0.373696371450733, -0.204550990421942, 0.934207962184333, -0.306025798509274, 0.812798367835215, -0.455532753255022, 0.125622634105054, 0.367604443133276, 0.55994486569804, 0.0753886324541624, 0.77813576524246, -0.166034479090781, 0.94565093122238, -0.296006132277826, -0.265234940757184, 0.416497397310383, 0.731437706958231, 0.270984781299447, -0.428663581410429, -0.312686040931128, -0.773226583138831, 0.598583108690796, -1.09057460676235, -0.685811206980829, 0.320672030715493, -0.216001011689699, -0.467434588986512, -0.412477454386901, 0.437774085441529, 0.177397734130052, -0.621856758100873, 0.965785439434402]</td>\n",
" <td>0.00892975733073</td>\n",
" <td>0.0517583539075</td>\n",
" <td>177</td>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>[False, True]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.0304283785262617, -0.31252256022358, -0.786594823395088, -0.213462165447043, -0.205921127201806, -0.160495098525847, -0.0572558143143236, 0.183220396715039, 0.464792217394048, 0.481483244258389, -0.517643614876286, -0.978683507901074, 0.16817691490251, -0.646816511296802, 0.186375390577986, 0.680958110801916, -1.08232636410793, -0.945175946274317, -0.491574430145303, -0.0814842224959727, -0.0190008166655649, 0.0481772284935738, -0.163702000177582, -0.674004897487282, 0.113518490341767, -0.638187403343937, 0.526552063250668, -0.274100343388661, 0.354431317955514, -0.428444014517539, 0.0946683130713131, -0.239646558966188, -0.288975110225114, 0.277634287723891, 0.109083762491799, -0.590472152297871, 0.30239084357163, -0.644378259476824, 0.518616701508965, -0.0310448850251757, -0.0616074328876328, -0.815709238025655, 0.533952545147382, -0.27885652806791, -0.169816368423772, 0.501969605761016, -0.0453904283532783, -0.296542126733526, -0.6744641877468, -1.04295422716397, 0.0998805368884473, -0.0581387461992054, 0.226951693444486, 0.183293767588985, -0.506636260378821, -0.182587120340722, -0.632707861890101, 0.165980897095258, -0.918139789600548, -0.770637770944717, 0.945310986017951, -1.02669717302724, 0.608102258578773, 0.22613081465608, 0.141010992575908, -0.154732938082785, 0.673276713057085, -0.252979879432229, 0.373696371450733, -0.204550990421942, 0.934207962184333, -0.306025798509274, 0.812798367835215, -0.455532753255022, 0.125622634105054, 0.367604443133276, 0.55994486569804, 0.0753886324541624, 0.77813576524246, -0.166034479090781, 0.94565093122238, -0.296006132277826, -0.265234940757184, 0.416497397310383, 0.731437706958231, 0.270984781299447, -0.428663581410429, -0.312686040931128, -0.773226583138831, 0.598583108690796, -1.09057460676235, -0.685811206980829, 0.320672030715493, -0.216001011689699, -0.467434588986512, -0.412477454386901, 0.437774085441529, 0.177397734130052, -0.621856758100873, 0.965785439434402], 0.00892975733073064, 0.0517583539075499, 177, 15L, 0L, [False, True])]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_svm_gaussian, houses_svm_gaussian_summary, houses_svm_gaussian_random;\n",
"\n",
"SELECT madlib.svm_classification( 'houses',\n",
" 'houses_svm_gaussian',\n",
" 'price < 150000',\n",
" 'ARRAY[1, tax, bath, size]',\n",
" 'gaussian',\n",
" 'n_components=100',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200, class_weight=balanced'\n",
" );\n",
"\n",
"SELECT * FROM houses_svm_gaussian;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Regression\n",
"# 1. Create input data set\n",
"For regression we use part of the well known abalone data set https://archive.ics.uci.edu/ml/datasets/abalone :"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"20 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone;\n",
"\n",
"CREATE TABLE abalone (id INT, sex TEXT, length FLOAT, diameter FLOAT, height FLOAT, rings INT);\n",
"\n",
"INSERT INTO abalone VALUES\n",
"(1,'M',0.455,0.365,0.095,15),\n",
"(2,'M',0.35,0.265,0.09,7),\n",
"(3,'F',0.53,0.42,0.135,9),\n",
"(4,'M',0.44,0.365,0.125,10),\n",
"(5,'I',0.33,0.255,0.08,7),\n",
"(6,'I',0.425,0.3,0.095,8),\n",
"(7,'F',0.53,0.415,0.15,20),\n",
"(8,'F',0.545,0.425,0.125,16),\n",
"(9,'M',0.475,0.37,0.125,9),\n",
"(10,'F',0.55,0.44,0.15,19),\n",
"(11,'F',0.525,0.38,0.14,14),\n",
"(12,'M',0.43,0.35,0.11,10),\n",
"(13,'M',0.49,0.38,0.135,11),\n",
"(14,'F',0.535,0.405,0.145,10),\n",
"(15,'F',0.47,0.355,0.1,10),\n",
"(16,'M',0.5,0.4,0.13,12),\n",
"(17,'I',0.355,0.28,0.085,7),\n",
"(18,'F',0.44,0.34,0.1,10),\n",
"(19,'M',0.365,0.295,0.08,7),\n",
"(20,'M',0.45,0.32,0.1,9);\n",
"\n",
"SELECT * FROM abalone ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Train linear regression model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[1.998949892503, 0.918517335088235, 0.712125758488304, 0.229379426728093]</td>\n",
" <td>8.29033306424</td>\n",
" <td>23.2251777867</td>\n",
" <td>100</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([1.998949892503, 0.918517335088235, 0.712125758488304, 0.229379426728093], 8.2903330642386, 23.2251777867403, 100, 20L, 0L, [None])]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_regression, abalone_svm_regression_summary;\n",
"\n",
"SELECT madlib.svm_regression('abalone',\n",
" 'abalone_svm_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_regression;"
]
},
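{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training call above passes no kernel, regularization, or optimizer arguments, so the model is fit with whatever defaults apply. As a quick check, and assuming the `abalone_svm_regression_summary` table named in the `DROP TABLE` statement above was created alongside the model, it can be queried to see which settings were actually used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT * -- kernel, optimizer, and regularization settings recorded at training time\n",
"FROM abalone_svm_regression_summary;"
]
},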
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using linear model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" <td>2.69859222736</td>\n",
" <td>2.69859222736</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" <td>2.52978843419</td>\n",
" <td>2.52978843419</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" <td>2.81582312127</td>\n",
" <td>2.81582312127</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" <td>2.69169585013</td>\n",
" <td>2.69169585013</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>2.50200303563</td>\n",
" <td>2.50200303563</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" <td>2.624748533</td>\n",
" <td>2.624748533</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" <td>2.81570318388</td>\n",
" <td>2.81570318388</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" <td>2.83086771582</td>\n",
" <td>2.83086771582</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" <td>2.72740458565</td>\n",
" <td>2.72740458565</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" <td>2.85187667455</td>\n",
" <td>2.85187667455</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" <td>2.78389240139</td>\n",
" <td>2.78389240139</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" <td>2.668388099</td>\n",
" <td>2.668388099</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" <td>2.75059739753</td>\n",
" <td>2.75059739753</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" <td>2.81202761584</td>\n",
" <td>2.81202761584</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>2.70639562693</td>\n",
" <td>2.70639562693</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" <td>2.77287818892</td>\n",
" <td>2.77287818892</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" <td>2.54391601011</td>\n",
" <td>2.54391601011</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>2.6681582205</td>\n",
" <td>2.6681582205</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>2.5626361727</td>\n",
" <td>2.5626361727</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" <td>2.66310087868</td>\n",
" <td>2.66310087868</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15, 2.69859222735555, 2.69859222735555),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7, 2.52978843418882, 2.52978843418882),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9, 2.81582312127315, 2.81582312127315),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10, 2.69169585013107, 2.69169585013107),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7, 2.50200303563489, 2.50200303563489),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8, 2.62474853300116, 2.62474853300116),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20, 2.81570318388163, 2.81570318388163),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16, 2.83086771582463, 2.83086771582463),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9, 2.7274045856516, 2.7274045856516),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19, 2.8518766745456, 2.8518766745456),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14, 2.78389240139182, 2.78389240139182),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10, 2.66838809900194, 2.66838809900194),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11, 2.75059739753009, 2.75059739753009),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10, 2.81202761583855, 2.81202761583855),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10, 2.70639562693063, 2.70639562693063),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12, 2.7728781889171, 2.7728781889171),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7, 2.54391601010794, 2.54391601010794),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10, 2.66815822050066, 2.66815822050066),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7, 2.56263617270251, 2.56263617270251),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9, 2.66310087868178, 2.66310087868178)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_regr;\n",
"\n",
"SELECT madlib.svm_predict('abalone_svm_regression',\n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_regr');\n",
"\n",
"SELECT * FROM abalone JOIN abalone_regr USING (id) ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMS error:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>9.08842735715</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(9.08842735715349,)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_regr USING (id);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Train using Gaussian model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[1.43663248508461, 0.69548130777053, -0.667820070034776, 0.353629960777887, 0.90572382703018, -1.35192985512522, 1.33643594466947, -1.71808935250261, 0.244907459899655, -1.16083721399717, 0.529533556029462, -1.47565163746093, 0.980976366290674, -1.57625949227502, 0.725799354053906, 0.621521219305506, -0.460388482856766, 1.62409092636038, -1.57065529672944, -1.04776403679369, -1.57215274912702, 0.414786913802637, 0.338630545155005, 1.67994043684046, 0.510271334529243, 1.56926726571838, -0.266055554149799, -1.58512365036022, -0.537619986488822, -1.53123126187998, 0.432059794750374, 1.10164971751016, -1.63470052461906, 1.37210730848619, 0.69195070273926, -1.51550206095593, -0.879588483736812, -0.37806986511409, -1.02110166282495, -0.905287480716121, -0.0893285495393008, -0.576435902131671, 1.24002228080107, -0.817209795372, 0.522281410113176, 0.156847005593254, -0.872148823590253, -1.43633009947157, -0.804330453999928, -0.926921599832849, -0.211959168386152, -0.804358663674064, -0.616861771161596, 1.74018724759497, -1.16264563996159, -1.63930487981442, -1.82413119548892, 0.794011977984717, 1.1698395008002, 1.61482847146405, 0.205981267231647, -0.458375714158526, -0.867267996489895, -0.308286903981698, 1.58324830425497, -0.827831442995643, -1.47580575745441, -0.445332216554912, 1.16538304159597, 1.37221102799479, 1.57289281402231, -0.66649715533837, -0.476569227105178, 0.29665562590112, 1.15590670076679, 1.46291613083677, 0.79992035125813, 1.09721250793166, -1.39887016649551, 0.190329304832203, 1.28663304442308, 1.42061537940844, -1.03946229472351, -0.556121905835198, 0.341622218455033, -0.899369783669295, 1.61878187535374, 1.57583095826418, -0.71553739367934, -1.59696136184546, -1.56192177099981, -0.174158730591109, 1.19626739396989, -1.37221525222703, 1.59932192137657, 0.435584312299149, 1.38585223624462, 0.449219561014516, -1.61927162148616, -1.76886002780722]</td>\n",
" <td>2.60162532006</td>\n",
" <td>1.04369725716</td>\n",
" <td>166</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([1.43663248508461, 0.69548130777053, -0.667820070034776, 0.353629960777887, 0.90572382703018, -1.35192985512522, 1.33643594466947, -1.71808935250261, 0.244907459899655, -1.16083721399717, 0.529533556029462, -1.47565163746093, 0.980976366290674, -1.57625949227502, 0.725799354053906, 0.621521219305506, -0.460388482856766, 1.62409092636038, -1.57065529672944, -1.04776403679369, -1.57215274912702, 0.414786913802637, 0.338630545155005, 1.67994043684046, 0.510271334529243, 1.56926726571838, -0.266055554149799, -1.58512365036022, -0.537619986488822, -1.53123126187998, 0.432059794750374, 1.10164971751016, -1.63470052461906, 1.37210730848619, 0.69195070273926, -1.51550206095593, -0.879588483736812, -0.37806986511409, -1.02110166282495, -0.905287480716121, -0.0893285495393008, -0.576435902131671, 1.24002228080107, -0.817209795372, 0.522281410113176, 0.156847005593254, -0.872148823590253, -1.43633009947157, -0.804330453999928, -0.926921599832849, -0.211959168386152, -0.804358663674064, -0.616861771161596, 1.74018724759497, -1.16264563996159, -1.63930487981442, -1.82413119548892, 0.794011977984717, 1.1698395008002, 1.61482847146405, 0.205981267231647, -0.458375714158526, -0.867267996489895, -0.308286903981698, 1.58324830425497, -0.827831442995643, -1.47580575745441, -0.445332216554912, 1.16538304159597, 1.37221102799479, 1.57289281402231, -0.66649715533837, -0.476569227105178, 0.29665562590112, 1.15590670076679, 1.46291613083677, 0.79992035125813, 1.09721250793166, -1.39887016649551, 0.190329304832203, 1.28663304442308, 1.42061537940844, -1.03946229472351, -0.556121905835198, 0.341622218455033, -0.899369783669295, 1.61878187535374, 1.57583095826418, -0.71553739367934, -1.59696136184546, -1.56192177099981, -0.174158730591109, 1.19626739396989, -1.37221525222703, 1.59932192137657, 0.435584312299149, 1.38585223624462, 0.449219561014516, -1.61927162148616, -1.76886002780722], 2.60162532005635, 1.04369725715644, 166, 20L, 0L, [None])]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, abalone_svm_gaussian_regression_random;\n",
"\n",
"SELECT madlib.svm_regression( 'abalone',\n",
" 'abalone_svm_gaussian_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]',\n",
" 'gaussian',\n",
" 'n_components=100',\n",
" '',\n",
" 'init_stepsize=1, max_iter=200'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_gaussian_regression;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Predict using Gaussian model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>rings</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>15</td>\n",
" <td>9.98603457721</td>\n",
" <td>9.98603457721</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>M</td>\n",
" <td>0.35</td>\n",
" <td>0.265</td>\n",
" <td>0.09</td>\n",
" <td>7</td>\n",
" <td>9.57759650003</td>\n",
" <td>9.57759650003</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.42</td>\n",
" <td>0.135</td>\n",
" <td>9</td>\n",
" <td>10.2223107905</td>\n",
" <td>10.2223107905</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>M</td>\n",
" <td>0.44</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>10</td>\n",
" <td>9.97689991817</td>\n",
" <td>9.97689991817</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>I</td>\n",
" <td>0.33</td>\n",
" <td>0.255</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>9.49509568417</td>\n",
" <td>9.49509568417</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>I</td>\n",
" <td>0.425</td>\n",
" <td>0.3</td>\n",
" <td>0.095</td>\n",
" <td>8</td>\n",
" <td>9.82953290649</td>\n",
" <td>9.82953290649</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>F</td>\n",
" <td>0.53</td>\n",
" <td>0.415</td>\n",
" <td>0.15</td>\n",
" <td>20</td>\n",
" <td>10.2265451261</td>\n",
" <td>10.2265451261</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>F</td>\n",
" <td>0.545</td>\n",
" <td>0.425</td>\n",
" <td>0.125</td>\n",
" <td>16</td>\n",
" <td>10.2457829829</td>\n",
" <td>10.2457829829</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>M</td>\n",
" <td>0.475</td>\n",
" <td>0.37</td>\n",
" <td>0.125</td>\n",
" <td>9</td>\n",
" <td>10.0580169046</td>\n",
" <td>10.0580169046</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>F</td>\n",
" <td>0.55</td>\n",
" <td>0.44</td>\n",
" <td>0.15</td>\n",
" <td>19</td>\n",
" <td>10.2825919797</td>\n",
" <td>10.2825919797</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>F</td>\n",
" <td>0.525</td>\n",
" <td>0.38</td>\n",
" <td>0.14</td>\n",
" <td>14</td>\n",
" <td>10.1755652941</td>\n",
" <td>10.1755652941</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>M</td>\n",
" <td>0.43</td>\n",
" <td>0.35</td>\n",
" <td>0.11</td>\n",
" <td>10</td>\n",
" <td>9.92374820403</td>\n",
" <td>9.92374820403</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>M</td>\n",
" <td>0.49</td>\n",
" <td>0.38</td>\n",
" <td>0.135</td>\n",
" <td>11</td>\n",
" <td>10.1066593545</td>\n",
" <td>10.1066593545</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>F</td>\n",
" <td>0.535</td>\n",
" <td>0.405</td>\n",
" <td>0.145</td>\n",
" <td>10</td>\n",
" <td>10.2224287437</td>\n",
" <td>10.2224287437</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>F</td>\n",
" <td>0.47</td>\n",
" <td>0.355</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>10.0099999803</td>\n",
" <td>10.0099999803</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>M</td>\n",
" <td>0.5</td>\n",
" <td>0.4</td>\n",
" <td>0.13</td>\n",
" <td>12</td>\n",
" <td>10.1445256879</td>\n",
" <td>10.1445256879</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>I</td>\n",
" <td>0.355</td>\n",
" <td>0.28</td>\n",
" <td>0.085</td>\n",
" <td>7</td>\n",
" <td>9.61132344417</td>\n",
" <td>9.61132344417</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>F</td>\n",
" <td>0.44</td>\n",
" <td>0.34</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>9.92513613473</td>\n",
" <td>9.92513613473</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>M</td>\n",
" <td>0.365</td>\n",
" <td>0.295</td>\n",
" <td>0.08</td>\n",
" <td>7</td>\n",
" <td>9.65721889547</td>\n",
" <td>9.65721889547</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>M</td>\n",
" <td>0.45</td>\n",
" <td>0.32</td>\n",
" <td>0.1</td>\n",
" <td>9</td>\n",
" <td>9.92028835238</td>\n",
" <td>9.92028835238</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'M', 0.455, 0.365, 0.095, 15, 9.98603457720518, 9.98603457720518),\n",
" (2, u'M', 0.35, 0.265, 0.09, 7, 9.57759650002543, 9.57759650002543),\n",
" (3, u'F', 0.53, 0.42, 0.135, 9, 10.2223107904678, 10.2223107904678),\n",
" (4, u'M', 0.44, 0.365, 0.125, 10, 9.97689991817358, 9.97689991817358),\n",
" (5, u'I', 0.33, 0.255, 0.08, 7, 9.49509568416916, 9.49509568416916),\n",
" (6, u'I', 0.425, 0.3, 0.095, 8, 9.82953290649331, 9.82953290649331),\n",
" (7, u'F', 0.53, 0.415, 0.15, 20, 10.2265451260768, 10.2265451260768),\n",
" (8, u'F', 0.545, 0.425, 0.125, 16, 10.2457829829034, 10.2457829829034),\n",
" (9, u'M', 0.475, 0.37, 0.125, 9, 10.0580169045793, 10.0580169045793),\n",
" (10, u'F', 0.55, 0.44, 0.15, 19, 10.282591979716, 10.282591979716),\n",
" (11, u'F', 0.525, 0.38, 0.14, 14, 10.1755652940929, 10.1755652940929),\n",
" (12, u'M', 0.43, 0.35, 0.11, 10, 9.92374820403203, 9.92374820403203),\n",
" (13, u'M', 0.49, 0.38, 0.135, 11, 10.1066593545246, 10.1066593545246),\n",
" (14, u'F', 0.535, 0.405, 0.145, 10, 10.2224287437228, 10.2224287437228),\n",
" (15, u'F', 0.47, 0.355, 0.1, 10, 10.0099999802503, 10.0099999802503),\n",
" (16, u'M', 0.5, 0.4, 0.13, 12, 10.1445256879087, 10.1445256879087),\n",
" (17, u'I', 0.355, 0.28, 0.085, 7, 9.61132344416669, 9.61132344416669),\n",
" (18, u'F', 0.44, 0.34, 0.1, 10, 9.92513613472998, 9.92513613472998),\n",
" (19, u'M', 0.365, 0.295, 0.08, 7, 9.65721889547397, 9.65721889547397),\n",
" (20, u'M', 0.45, 0.32, 0.1, 9, 9.9202883523785, 9.9202883523785)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_gaussian_regr;\n",
"\n",
"SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_gaussian_regr');\n",
"\n",
"SELECT * FROM abalone JOIN abalone_gaussian_regr USING (id) ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the RMS error. Note this produces a more accurate result than the linear case for this small data set:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3.75666107851</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3.75666107851157,)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_gaussian_regr USING (id);"
]
},
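{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both prediction tables are now available, so the two RMS errors can be placed side by side in a single result. The query below is a minimal sketch that reuses the `abalone_regr` and `abalone_gaussian_regr` tables created above; the `model` label column is introduced here only for readability and is not part of any MADlib output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT CAST('linear' AS TEXT) AS model, -- label for the linear model's predictions\n",
" SQRT(AVG((rings - prediction) * (rings - prediction))) AS rms_error\n",
"FROM abalone JOIN abalone_regr USING (id)\n",
"UNION ALL\n",
"SELECT CAST('gaussian' AS TEXT) AS model, -- label for the Gaussian kernel model's predictions\n",
" SQRT(AVG((rings - prediction) * (rings - prediction))) AS rms_error\n",
"FROM abalone JOIN abalone_gaussian_regr USING (id);"
]
},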
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Cross validation\n",
"Let's run cross validation for different initial step sizes and lambda values:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[1.43663141048166, 0.69590902120895, -0.667071782085283, 0.353774339959869, 0.905828289634807, -1.35169013920124, 1.33646358226853, -1.7180201996699, 0.24481068329437, -1.16069594203275, 0.529737373827505, -1.47546535187187, 0.980691636646386, -1.57624210352963, 0.725851837240663, 0.621725612776311, -0.460138729733476, 1.62396471237433, -1.57063525564175, -1.04781069512703, -1.57213080053901, 0.415306056428304, 0.338679918908941, 1.6801320999485, 0.510208152821366, 1.56936887737261, -0.266001406651365, -1.58507523905296, -0.537812259067129, -1.5312628382147, 0.432237469443731, 1.10170029776598, -1.63469708889792, 1.3721951919766, 0.691837885175724, -1.5155853305037, -0.879575274457664, -0.378046928285422, -1.02084583660482, -0.905052544480774, -0.0895711922931681, -0.576177025418454, 1.23999477285708, -0.817661938486405, 0.522699926739854, 0.157003593221511, -0.871879320669594, -1.43641645044609, -0.804548716520958, -0.926662211340699, -0.211709523145134, -0.804204137088408, -0.616360942183748, 1.74003514481556, -1.16277414575498, -1.63929118823194, -1.82407340498873, 0.794042280007498, 1.17011142340466, 1.61469087973011, 0.206231124021691, -0.458287050727443, -0.86716709482106, -0.308198290202033, 1.58323228771844, -0.827377112621467, -1.47575759821081, -0.445203431415296, 1.16551153170496, 1.37211944359085, 1.57289635294413, -0.666315273550231, -0.476600127294968, 0.296662747500788, 1.15584301004967, 1.46263374958392, 0.799907772688504, 1.09708642804713, -1.39887637044597, 0.190356100346786, 1.28636855084754, 1.42039838933069, -1.03949687018857, -0.555861325085124, 0.341382237642365, -0.899593443791824, 1.61885754858329, 1.57586783372196, -0.71510056870506, -1.59701588256911, -1.56188454682269, -0.174291105678036, 1.19615194057151, -1.37224827664802, 1.59933543202374, 0.435995073036911, 1.38582746743955, 0.449261405728548, -1.61911544479704, -1.7689529702645]</td>\n",
" <td>2.60165100186</td>\n",
" <td>1.04370217846</td>\n",
" <td>165</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>[None]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([1.43663141048166, 0.69590902120895, -0.667071782085283, 0.353774339959869, 0.905828289634807, -1.35169013920124, 1.33646358226853, -1.7180201996699, 0.24481068329437, -1.16069594203275, 0.529737373827505, -1.47546535187187, 0.980691636646386, -1.57624210352963, 0.725851837240663, 0.621725612776311, -0.460138729733476, 1.62396471237433, -1.57063525564175, -1.04781069512703, -1.57213080053901, 0.415306056428304, 0.338679918908941, 1.6801320999485, 0.510208152821366, 1.56936887737261, -0.266001406651365, -1.58507523905296, -0.537812259067129, -1.5312628382147, 0.432237469443731, 1.10170029776598, -1.63469708889792, 1.3721951919766, 0.691837885175724, -1.5155853305037, -0.879575274457664, -0.378046928285422, -1.02084583660482, -0.905052544480774, -0.0895711922931681, -0.576177025418454, 1.23999477285708, -0.817661938486405, 0.522699926739854, 0.157003593221511, -0.871879320669594, -1.43641645044609, -0.804548716520958, -0.926662211340699, -0.211709523145134, -0.804204137088408, -0.616360942183748, 1.74003514481556, -1.16277414575498, -1.63929118823194, -1.82407340498873, 0.794042280007498, 1.17011142340466, 1.61469087973011, 0.206231124021691, -0.458287050727443, -0.86716709482106, -0.308198290202033, 1.58323228771844, -0.827377112621467, -1.47575759821081, -0.445203431415296, 1.16551153170496, 1.37211944359085, 1.57289635294413, -0.666315273550231, -0.476600127294968, 0.296662747500788, 1.15584301004967, 1.46263374958392, 0.799907772688504, 1.09708642804713, -1.39887637044597, 0.190356100346786, 1.28636855084754, 1.42039838933069, -1.03949687018857, -0.555861325085124, 0.341382237642365, -0.899593443791824, 1.61885754858329, 1.57586783372196, -0.71510056870506, -1.59701588256911, -1.56188454682269, -0.174291105678036, 1.19615194057151, -1.37224827664802, 1.59933543202374, 0.435995073036911, 1.38582746743955, 0.449261405728548, -1.61911544479704, -1.7689529702645], 2.60165100186252, 1.04370217845818, 165, 20L, 0L, [None])]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_svm_gaussian_regression, abalone_svm_gaussian_regression_summary, \n",
"abalone_svm_gaussian_regression_random, abalone_svm_gaussian_regression_cv;\n",
"\n",
"SELECT madlib.svm_regression( 'abalone',\n",
" 'abalone_svm_gaussian_regression',\n",
" 'rings',\n",
" 'ARRAY[1, length, diameter, height]',\n",
" 'gaussian',\n",
" 'n_components=100',\n",
" '',\n",
" 'init_stepsize=[0.01,1], n_folds=3, max_iter=200, lambda=[0.01, 0.1, 0.5], validation_result=abalone_svm_gaussian_regression_cv'\n",
" );\n",
"\n",
"SELECT * FROM abalone_svm_gaussian_regression;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the summary table showing the final model parameters are those that produced \n",
"the lowest error in the cross validation runs:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>method</th>\n",
" <th>version_number</th>\n",
" <th>source_table</th>\n",
" <th>model_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>kernel_func</th>\n",
" <th>kernel_params</th>\n",
" <th>grouping_col</th>\n",
" <th>optim_params</th>\n",
" <th>reg_params</th>\n",
" <th>num_all_groups</th>\n",
" <th>num_failed_groups</th>\n",
" <th>total_rows_processed</th>\n",
" <th>total_rows_skipped</th>\n",
" </tr>\n",
" <tr>\n",
" <td>SVR</td>\n",
" <td>1.17-dev</td>\n",
" <td>abalone</td>\n",
" <td>abalone_svm_gaussian_regression</td>\n",
" <td>rings</td>\n",
" <td>ARRAY[1, length, diameter, height]</td>\n",
" <td>gaussian</td>\n",
" <td>gamma=0.25, n_components=100,random_state=1, fit_intercept=False, fit_in_memory=True</td>\n",
" <td>NULL</td>\n",
" <td> init_stepsize=1.0,<br> decay_factor=0.9,<br> max_iter=200,<br> tolerance=1e-10,<br> epsilon=0.01,<br> eps_table=,<br> class_weight=<br> </td>\n",
" <td>lambda=0.01, norm=l2, n_folds=3</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'SVR', u'1.17-dev', u'abalone', u'abalone_svm_gaussian_regression', u'rings', u'ARRAY[1, length, diameter, height]', u'gaussian', u'gamma=0.25, n_components=100,random_state=1, fit_intercept=False, fit_in_memory=True', u'NULL', u' init_stepsize=1.0,\\n decay_factor=0.9,\\n max_iter=200,\\n tolerance=1e-10,\\n epsilon=0.01,\\n eps_table=,\\n class_weight=\\n ', u'lambda=0.01, norm=l2, n_folds=3', 1, 0, 20L, 0L)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql SELECT * FROM abalone_svm_gaussian_regression_summary;"
]
},
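{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the summary table is wide, it can help to select only the columns that record the chosen parameters. A minimal sketch, using the column names shown in the summary output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT kernel_params, optim_params, reg_params\n",
"-- column names are taken from the summary table output; adjust if your MADlib version differs\n",
"FROM abalone_svm_gaussian_regression_summary;"
]
},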
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the values for cross validation:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>init_stepsize</th>\n",
" <th>lambda</th>\n",
" <th>mean_score</th>\n",
" <th>std_dev_score</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.01</td>\n",
" <td>-10.6826786731</td>\n",
" <td>2.89736952916</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.5</td>\n",
" <td>-10.6896583532</td>\n",
" <td>2.89700539715</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.1</td>\n",
" <td>-3.81154741847</td>\n",
" <td>2.20160973525</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.01</td>\n",
" <td>0.1</td>\n",
" <td>-10.6839690886</td>\n",
" <td>2.89730222282</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.01</td>\n",
" <td>-3.62237265846</td>\n",
" <td>2.30389631902</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1.0</td>\n",
" <td>0.5</td>\n",
" <td>-3.9659217506</td>\n",
" <td>2.24231164049</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('0.01'), Decimal('0.01'), Decimal('-10.6826786731'), Decimal('2.89736952916')),\n",
" (Decimal('0.01'), Decimal('0.5'), Decimal('-10.6896583532'), Decimal('2.89700539715')),\n",
" (Decimal('1.0'), Decimal('0.1'), Decimal('-3.81154741847'), Decimal('2.20160973525')),\n",
" (Decimal('0.01'), Decimal('0.1'), Decimal('-10.6839690886'), Decimal('2.89730222282')),\n",
" (Decimal('1.0'), Decimal('0.01'), Decimal('-3.62237265846'), Decimal('2.30389631902')),\n",
" (Decimal('1.0'), Decimal('0.5'), Decimal('-3.9659217506'), Decimal('2.24231164049'))]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM abalone_svm_gaussian_regression_cv;"
]
},
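{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch, the best parameter combination can also be read directly from the validation table by sorting on `mean_score`. This assumes, consistent with the parameters reported in the summary table, that a higher `mean_score` indicates a better fit:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT * FROM abalone_svm_gaussian_regression_cv\n",
"-- assumption: higher mean_score is better, so the top row is the selected combination\n",
"ORDER BY mean_score DESC\n",
"LIMIT 1;"
]
},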
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7. Predict using cross-validated Gaussian regression model:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>svm_predict</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS abalone_gaussian_regr;\n",
"SELECT madlib.svm_predict('abalone_svm_gaussian_regression', \n",
" 'abalone', \n",
" 'id', \n",
" 'abalone_gaussian_regr');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the RMS error. Note this produces a more accurate result than the previous run with the Gaussian kernel:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>rms_error</th>\n",
" </tr>\n",
" <tr>\n",
" <td>3.7567001982</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(3.7567001982019,)]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT SQRT(AVG((rings-prediction)*(rings-prediction))) as rms_error FROM abalone \n",
"JOIN abalone_gaussian_regr USING (id);"
]
},
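{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMS error weights large misses more heavily than small ones. As a complementary sketch (not part of the original workflow), the mean absolute error can be computed from the same join; `mean_abs_error` is an illustrative alias:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT AVG(ABS(rings - prediction)) AS mean_abs_error  -- illustrative alias, not a MADlib output column\n",
"FROM abalone\n",
"JOIN abalone_gaussian_regr USING (id);"
]
},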
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Novelty detection \n",
"# 1. Train a non-linear one-class SVM\n",
"Use a Gaussian kernel using the housing data set. Note that the dependent variable is not a parameter for one-class:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>coef</th>\n",
" <th>loss</th>\n",
" <th>norm_of_gradient</th>\n",
" <th>num_iterations</th>\n",
" <th>num_rows_processed</th>\n",
" <th>num_rows_skipped</th>\n",
" <th>dep_var_mapping</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, -0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694]</td>\n",
" <td>0.944029805394</td>\n",
" <td>14.5264894419</td>\n",
" <td>100</td>\n",
" <td>16</td>\n",
" <td>-1</td>\n",
" <td>[-1.0, 1.0]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([0.0207901288823711, -0.00103437489314969, 0.00407820868429805, 0.0274910360546609, 0.0105696547048294, -0.00313332466259033, -0.0216703145014011, 0.0363248037825208, -0.0211400498166549, -0.00827402232219555, 0.0265909439934851, 0.0282462482323058, -0.0407407195393746, 0.0191290942177852, -0.00313542082923064, -0.0191740603622109, 0.0143626646548982, -0.0620527674181034, -0.000319831622794402, 0.00388104709972051, 0.00248129433065678, 0.00764915273571186, 0.014492283562898, 0.0184730815984353, -0.00745840880633255, -0.0232208663374367, -0.010724056217189, 0.00541494627043399, 0.0150679846777238, 0.0204022414812525, -0.0294626167089617, -0.00399506510201406, -0.0231139983460727, 0.0242203153309423, -0.0421196963278802, 0.0112202149916885, -0.00720876723524249, 0.0213674589734111, -0.00260107056222295, -0.0130652059444514, 0.0710580616012718, 0.0519822855717347, 0.00961050532247376, 0.0390561950837254, -0.0152620688050253, 0.0100336750737295, 0.0632488712630204, -0.0549714494076944, -0.007684860916257, 0.0322104572263339, -0.00832311210931705, 0.0279669244721609, 0.0455147539995411, -0.0639670005155479, -0.00965055072583972, 0.00648588125681694], 0.944029805394336, 14.5264894418914, 100, 16L, -1L, [-1.0, 1.0])]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_one_class_gaussian, houses_one_class_gaussian_summary, houses_one_class_gaussian_random;\n",
"\n",
"SELECT madlib.svm_one_class('houses',\n",
" 'houses_one_class_gaussian',\n",
" 'ARRAY[1,tax,bedroom,bath,size,lot,price]',\n",
" 'gaussian',\n",
" 'gamma=0.5,n_components=55, random_state=3',\n",
" NULL,\n",
" 'max_iter=100, init_stepsize=10,lambda=10, tolerance=0'\n",
" );\n",
"\n",
"SELECT * FROM houses_one_class_gaussian;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Create test data\n",
"For the novelty detection using one-class, let's create a test data set using the last 3 values from the training set plus an outlier at the end (10x price):"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"4 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>650000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 3100, 3, 2.0, 140000, 1760, 38000),\n",
" (2, 2070, 2, 3.0, 148000, 1550, 14000),\n",
" (3, 650, 3, 1.5, 65000, 1450, 12000),\n",
" (4, 650, 3, 1.5, 650000, 1450, 12000)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_one_class_test;\n",
"\n",
"CREATE TABLE houses_one_class_test (id INT, tax INT, bedroom INT, bath FLOAT, price INT,\n",
" size INT, lot INT);\n",
"\n",
"INSERT INTO houses_one_class_test VALUES \n",
" (1 , 3100 , 3 , 2 , 140000 , 1760 , 38000),\n",
" (2 , 2070 , 2 , 3 , 148000 , 1550 , 14000),\n",
" (3 , 650 , 3 , 1.5 , 65000 , 1450 , 12000),\n",
" (4 , 650 , 3 , 1.5 , 650000 , 1450 , 12000);\n",
" \n",
"SELECT * FROM houses_one_class_test ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Predict using Gaussian one-class novelty detection model\n",
"Result shows the last row predicted to be novel:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>price</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>prediction</th>\n",
" <th>decision_function</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>140000</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>1.0</td>\n",
" <td>0.0662278474212</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>148000</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>1.0</td>\n",
" <td>0.092124936453</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>65000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>1.0</td>\n",
" <td>0.03415206006</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>650000</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>-1.0</td>\n",
" <td>-0.0131918729845</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 3100, 3, 2.0, 140000, 1760, 38000, 1.0, 0.066227847421185),\n",
" (2, 2070, 2, 3.0, 148000, 1550, 14000, 1.0, 0.0921249364529948),\n",
" (3, 650, 3, 1.5, 65000, 1450, 12000, 1.0, 0.0341520600599523),\n",
" (4, 650, 3, 1.5, 650000, 1450, 12000, -1.0, -0.0131918729845241)]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS houses_pred;\n",
"\n",
"SELECT madlib.svm_predict('houses_one_class_gaussian', \n",
" 'houses_one_class_test', \n",
" 'id', \n",
" 'houses_pred');\n",
"\n",
"SELECT * FROM houses_one_class_test JOIN houses_pred USING (id) ORDER BY id;"
]
}
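,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On a larger test set it is usually more practical to list only the rows flagged as novel. A minimal sketch, assuming (as in the output above) that a `prediction` of -1 marks a novel row and that more negative `decision_function` values lie further outside the learned boundary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SELECT id, prediction, decision_function\n",
"FROM houses_pred\n",
"WHERE prediction < 0  -- assumption: -1 marks rows flagged as novel\n",
"ORDER BY decision_function;"
]
}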
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}