blob: 69688bf3a9bd5ba4a2dfab544f3492bdfcf131d1 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# K-means clustering\n",
"This module can compute clusters given the number of centroids k as an input, using a variety of seeding methods. It can also automatically select the best k value from a range of suggested k values, using the simplified silhouette method or the elbow method.\n",
"\n",
"## Table of contents\n",
"\n",
"<a href=\"#setup\">0. Setup</a>\n",
"\n",
"<a href=\"#single_k\">1. Clustering for single k value</a>\n",
"\n",
"<a href=\"#range_k\">2. Clustering for a range of k values</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"setup\"></a>\n",
"# 0. Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
" \"You should import from traitlets.config instead.\", ShimWarning)\n",
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
"#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
"\n",
"# Greenplum Database 5.x on GCP - via tunnel\n",
"%sql postgresql://gpadmin@localhost:8000/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"single_k\"></a>\n",
"# Clustering for single k value\n",
"\n",
"# 1. Input data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"10 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>points</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>[13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]),\n",
" (2, [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]),\n",
" (3, [13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]),\n",
" (4, [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]),\n",
" (5, [13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]),\n",
" (6, [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]),\n",
" (7, [14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]),\n",
" (8, [14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]),\n",
" (9, [14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]),\n",
" (10, [13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0])]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_sample;\n",
"\n",
"CREATE TABLE km_sample(pid int, points double precision[]);\n",
"\n",
"INSERT INTO km_sample VALUES\n",
"(1, '{14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065}'),\n",
"(2, '{13.2, 1.78, 2.14, 11.2, 1, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050}'),\n",
"(3, '{13.16, 2.36, 2.67, 18.6, 101, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185}'),\n",
"(4, '{14.37, 1.95, 2.5, 16.8, 113, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480}'),\n",
"(5, '{13.24, 2.59, 2.87, 21, 118, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735}'),\n",
"(6, '{14.2, 1.76, 2.45, 15.2, 112, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450}'),\n",
"(7, '{14.39, 1.87, 2.45, 14.6, 96, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290}'),\n",
"(8, '{14.06, 2.15, 2.61, 17.6, 121, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295}'),\n",
"(9, '{14.83, 1.64, 2.17, 14, 97, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045}'),\n",
"(10, '{13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, 3.55, 1045}');\n",
"\n",
"SELECT * FROM km_sample ORDER BY pid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Run k-means clustering using kmeans++ for centroid seeding\n",
"Use squared Euclidean distance, which is a commonly used distance function."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>centroids</th>\n",
" <th>cluster_variance</th>\n",
" <th>objective_fn</th>\n",
" <th>frac_reassigned</th>\n",
" <th>num_iterations</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]]</td>\n",
" <td>[90512.324426408, 60672.638245208]</td>\n",
" <td>151184.962672</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]], [90512.324426408, 60672.638245208], 151184.962671616, 0.0, 2)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_result;\n",
"\n",
"CREATE TABLE km_result AS\n",
"SELECT * FROM madlib.kmeanspp( 'km_sample', -- Table of source data\n",
" 'points', -- Column containing point coordinates \n",
" 2, -- Number of centroids to calculate\n",
" 'madlib.squared_dist_norm2', -- Distance function\n",
" 'madlib.avg', -- Aggregate function\n",
" 20, -- Maximum number of iterations\n",
" 0.001 -- Minimum fraction of centroids reassigned to keep iterating \n",
" );\n",
"\n",
"SELECT * FROM km_result;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Simplified silhouette coefficient\n",
"Compute the average simplified silhouette coefficient for the whole data set. Make sure to use the same distance function as in the k-means run above."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>simple_silhouette</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.872087020147</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.872087020146542,)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM madlib.simple_silhouette( 'km_sample', -- Input points table\n",
" 'points', -- Column containing points\n",
" (SELECT centroids FROM km_result), -- Centroids\n",
" 'madlib.squared_dist_norm2' -- Distance function\n",
" );"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now calculate the simplified silhouette coefficient for each point in the data set:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>centroid_id</th>\n",
" <th>neighbor_centroid_id</th>\n",
" <th>silh</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.902123603766</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.88017393665</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.382089480836</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.919141622264</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.822664572979</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.943394365443</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.970809939945</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.977109993192</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.961796151461</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.961566534928</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 0, 1, 0.902123603765829),\n",
" (2, 0, 1, 0.880173936650081),\n",
" (3, 1, 0, 0.382089480836045),\n",
" (4, 1, 0, 0.919141622264229),\n",
" (5, 0, 1, 0.822664572979012),\n",
" (6, 1, 0, 0.943394365442933),\n",
" (7, 1, 0, 0.97080993994542),\n",
" (8, 1, 0, 0.977109993192101),\n",
" (9, 0, 1, 0.961796151461416),\n",
" (10, 0, 1, 0.961566534928355)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_points_silh;\n",
"\n",
"SELECT * FROM madlib.simple_silhouette_points( 'km_sample', -- Input points table\n",
" 'km_points_silh', -- Output table\n",
" 'pid', -- Point ID column in input table\n",
" 'points', -- Points column in input table\n",
" 'km_result', -- Centroids table\n",
" 'centroids', -- Column in centroids table containing centroids \n",
" 'madlib.squared_dist_norm2' -- Distance function\n",
" );\n",
"\n",
"SELECT * FROM km_points_silh ORDER BY pid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Cluster assignment for each point\n",
"Use the closest_column() function to map each point to the cluster that it belongs to. "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"10 rows affected.\n",
"Done.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>points</th>\n",
" <th>cluster_id</th>\n",
" <th>distance</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]</td>\n",
" <td>0</td>\n",
" <td>7435.21009152</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]</td>\n",
" <td>0</td>\n",
" <td>11468.6944211</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]</td>\n",
" <td>1</td>\n",
" <td>24088.7121562</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]</td>\n",
" <td>1</td>\n",
" <td>19623.3917358</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]</td>\n",
" <td>0</td>\n",
" <td>64929.1866187</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]</td>\n",
" <td>1</td>\n",
" <td>12114.3606538</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]</td>\n",
" <td>1</td>\n",
" <td>2664.07117376</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]</td>\n",
" <td>1</td>\n",
" <td>2182.10252576</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]</td>\n",
" <td>0</td>\n",
" <td>3330.15771392</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>[13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0]</td>\n",
" <td>0</td>\n",
" <td>3349.07558113</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0], 0, 7435.2100915204),\n",
" (2, [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], 0, 11468.6944211204),\n",
" (3, [13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], 1, 24088.7121561664),\n",
" (4, [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], 1, 19623.3917357604),\n",
" (5, [13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], 0, 64929.1866187204),\n",
" (6, [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], 1, 12114.3606537604),\n",
" (7, [14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0], 1, 2664.0711737604),\n",
" (8, [14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0], 1, 2182.1025257604),\n",
" (9, [14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0], 0, 3330.1577139204),\n",
" (10, [13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0], 0, 3349.0755811264)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS point_cluster_map;\n",
"\n",
"CREATE TABLE point_cluster_map AS \n",
"SELECT data.*, (madlib.closest_column(centroids, points, 'madlib.squared_dist_norm2')).*\n",
"FROM km_sample as data, km_result;\n",
"\n",
"ALTER TABLE point_cluster_map RENAME column_id to cluster_id; -- change column name\n",
"SELECT * FROM point_cluster_map ORDER BY pid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Display cluster ID\n",
"There are two steps to get the cluster ID associated with each centroid's coordinates, if you need it. First, unnest the 2-D centroids array to get a set of 1-D centroid arrays:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"2 rows affected.\n",
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>unnest_row_id</th>\n",
" <th>unnest_result</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]),\n",
" (2, [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0])]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_centroids_unnest;\n",
"\n",
"-- Run unnest function\n",
"CREATE TABLE km_centroids_unnest AS\n",
"SELECT (madlib.array_unnest_2d_to_1d(centroids)).*\n",
"FROM km_result;\n",
"\n",
"SELECT * FROM km_centroids_unnest ORDER BY 1;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the ID column returned by 'array_unnest_2d_to_1d()' is just a row ID and not the cluster ID assigned by k-means. The second step to get the cluster_id is:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>unnest_row_id</th>\n",
" <th>unnest_result</th>\n",
" <th>cluster_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], 0),\n",
" (2, [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0], 1)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT cent.*, (madlib.closest_column(centroids, unnest_result, 'madlib.squared_dist_norm2')).column_id as cluster_id\n",
"FROM km_centroids_unnest as cent, km_result\n",
"ORDER BY cent.unnest_row_id;"
]
},
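{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Array input\n",
"Run the same clustering on data whose coordinates are stored in separate columns rather than in a single array column, passing them to k-means as an array expression. Create the input table:"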
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"10 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>p1</th>\n",
" <th>p2</th>\n",
" <th>p3</th>\n",
" <th>p4</th>\n",
" <th>p5</th>\n",
" <th>p6</th>\n",
" <th>p7</th>\n",
" <th>p8</th>\n",
" <th>p9</th>\n",
" <th>p10</th>\n",
" <th>p11</th>\n",
" <th>p12</th>\n",
" <th>p13</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>14.23</td>\n",
" <td>1.71</td>\n",
" <td>2.43</td>\n",
" <td>15.6</td>\n",
" <td>127.0</td>\n",
" <td>2.8</td>\n",
" <td>3.06</td>\n",
" <td>0.28</td>\n",
" <td>2.29</td>\n",
" <td>5.64</td>\n",
" <td>1.04</td>\n",
" <td>3.92</td>\n",
" <td>1065.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>13.2</td>\n",
" <td>1.78</td>\n",
" <td>2.14</td>\n",
" <td>11.2</td>\n",
" <td>1.0</td>\n",
" <td>2.65</td>\n",
" <td>2.76</td>\n",
" <td>0.26</td>\n",
" <td>1.28</td>\n",
" <td>4.38</td>\n",
" <td>1.05</td>\n",
" <td>3.49</td>\n",
" <td>1050.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>13.16</td>\n",
" <td>2.36</td>\n",
" <td>2.67</td>\n",
" <td>18.6</td>\n",
" <td>101.0</td>\n",
" <td>2.8</td>\n",
" <td>3.24</td>\n",
" <td>0.3</td>\n",
" <td>2.81</td>\n",
" <td>5.6799</td>\n",
" <td>1.03</td>\n",
" <td>3.17</td>\n",
" <td>1185.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>14.37</td>\n",
" <td>1.95</td>\n",
" <td>2.5</td>\n",
" <td>16.8</td>\n",
" <td>113.0</td>\n",
" <td>3.85</td>\n",
" <td>3.49</td>\n",
" <td>0.24</td>\n",
" <td>2.18</td>\n",
" <td>7.8</td>\n",
" <td>0.86</td>\n",
" <td>3.45</td>\n",
" <td>1480.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>13.24</td>\n",
" <td>2.59</td>\n",
" <td>2.87</td>\n",
" <td>21.0</td>\n",
" <td>118.0</td>\n",
" <td>2.8</td>\n",
" <td>2.69</td>\n",
" <td>0.39</td>\n",
" <td>1.82</td>\n",
" <td>4.32</td>\n",
" <td>1.04</td>\n",
" <td>2.93</td>\n",
" <td>735.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>14.2</td>\n",
" <td>1.76</td>\n",
" <td>2.45</td>\n",
" <td>15.2</td>\n",
" <td>112.0</td>\n",
" <td>3.27</td>\n",
" <td>3.39</td>\n",
" <td>0.34</td>\n",
" <td>1.97</td>\n",
" <td>6.75</td>\n",
" <td>1.05</td>\n",
" <td>2.85</td>\n",
" <td>1450.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>14.39</td>\n",
" <td>1.87</td>\n",
" <td>2.45</td>\n",
" <td>14.6</td>\n",
" <td>96.0</td>\n",
" <td>2.5</td>\n",
" <td>2.52</td>\n",
" <td>0.3</td>\n",
" <td>1.98</td>\n",
" <td>5.25</td>\n",
" <td>1.02</td>\n",
" <td>3.58</td>\n",
" <td>1290.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>14.06</td>\n",
" <td>2.15</td>\n",
" <td>2.61</td>\n",
" <td>17.6</td>\n",
" <td>121.0</td>\n",
" <td>2.6</td>\n",
" <td>2.51</td>\n",
" <td>0.31</td>\n",
" <td>1.25</td>\n",
" <td>5.05</td>\n",
" <td>1.06</td>\n",
" <td>3.58</td>\n",
" <td>1295.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>14.83</td>\n",
" <td>1.64</td>\n",
" <td>2.17</td>\n",
" <td>14.0</td>\n",
" <td>97.0</td>\n",
" <td>2.8</td>\n",
" <td>2.98</td>\n",
" <td>0.29</td>\n",
" <td>1.98</td>\n",
" <td>5.2</td>\n",
" <td>1.08</td>\n",
" <td>2.85</td>\n",
" <td>1045.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>13.86</td>\n",
" <td>1.35</td>\n",
" <td>2.27</td>\n",
" <td>16.0</td>\n",
" <td>98.0</td>\n",
" <td>2.98</td>\n",
" <td>3.15</td>\n",
" <td>0.22</td>\n",
" <td>1.85</td>\n",
" <td>7.2199</td>\n",
" <td>1.01</td>\n",
" <td>3.55</td>\n",
" <td>1045.0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0),\n",
" (2, 13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0),\n",
" (3, 13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0),\n",
" (4, 14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0),\n",
" (5, 13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0),\n",
" (6, 14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0),\n",
" (7, 14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0),\n",
" (8, 14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0),\n",
" (9, 14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0),\n",
" (10, 13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_arrayin CASCADE;\n",
"\n",
"CREATE TABLE km_arrayin(pid int, \n",
" p1 float, \n",
" p2 float, \n",
" p3 float,\n",
" p4 float, \n",
" p5 float, \n",
" p6 float,\n",
" p7 float, \n",
" p8 float, \n",
" p9 float,\n",
" p10 float, \n",
" p11 float, \n",
" p12 float,\n",
" p13 float);\n",
"\n",
"INSERT INTO km_arrayin VALUES\n",
"(1, 14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065),\n",
"(2, 13.2, 1.78, 2.14, 11.2, 1, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050),\n",
"(3, 13.16, 2.36, 2.67, 18.6, 101, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185),\n",
"(4, 14.37, 1.95, 2.5, 16.8, 113, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480),\n",
"(5, 13.24, 2.59, 2.87, 21, 118, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735),\n",
"(6, 14.2, 1.76, 2.45, 15.2, 112, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450),\n",
"(7, 14.39, 1.87, 2.45, 14.6, 96, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290),\n",
"(8, 14.06, 2.15, 2.61, 17.6, 121, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295),\n",
"(9, 14.83, 1.64, 2.17, 14, 97, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045),\n",
"(10, 13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, 3.55, 1045);\n",
"\n",
"SELECT * FROM km_arrayin ORDER BY pid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now find the cluster assignment for each point using random seeding:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>p1</th>\n",
" <th>p2</th>\n",
" <th>p3</th>\n",
" <th>p4</th>\n",
" <th>p5</th>\n",
" <th>p6</th>\n",
" <th>p7</th>\n",
" <th>p8</th>\n",
" <th>p9</th>\n",
" <th>p10</th>\n",
" <th>p11</th>\n",
" <th>p12</th>\n",
" <th>p13</th>\n",
" <th>cluster_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>14.23</td>\n",
" <td>1.71</td>\n",
" <td>2.43</td>\n",
" <td>15.6</td>\n",
" <td>127.0</td>\n",
" <td>2.8</td>\n",
" <td>3.06</td>\n",
" <td>0.28</td>\n",
" <td>2.29</td>\n",
" <td>5.64</td>\n",
" <td>1.04</td>\n",
" <td>3.92</td>\n",
" <td>1065.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>13.2</td>\n",
" <td>1.78</td>\n",
" <td>2.14</td>\n",
" <td>11.2</td>\n",
" <td>1.0</td>\n",
" <td>2.65</td>\n",
" <td>2.76</td>\n",
" <td>0.26</td>\n",
" <td>1.28</td>\n",
" <td>4.38</td>\n",
" <td>1.05</td>\n",
" <td>3.49</td>\n",
" <td>1050.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>13.16</td>\n",
" <td>2.36</td>\n",
" <td>2.67</td>\n",
" <td>18.6</td>\n",
" <td>101.0</td>\n",
" <td>2.8</td>\n",
" <td>3.24</td>\n",
" <td>0.3</td>\n",
" <td>2.81</td>\n",
" <td>5.6799</td>\n",
" <td>1.03</td>\n",
" <td>3.17</td>\n",
" <td>1185.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>14.37</td>\n",
" <td>1.95</td>\n",
" <td>2.5</td>\n",
" <td>16.8</td>\n",
" <td>113.0</td>\n",
" <td>3.85</td>\n",
" <td>3.49</td>\n",
" <td>0.24</td>\n",
" <td>2.18</td>\n",
" <td>7.8</td>\n",
" <td>0.86</td>\n",
" <td>3.45</td>\n",
" <td>1480.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>13.24</td>\n",
" <td>2.59</td>\n",
" <td>2.87</td>\n",
" <td>21.0</td>\n",
" <td>118.0</td>\n",
" <td>2.8</td>\n",
" <td>2.69</td>\n",
" <td>0.39</td>\n",
" <td>1.82</td>\n",
" <td>4.32</td>\n",
" <td>1.04</td>\n",
" <td>2.93</td>\n",
" <td>735.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>14.2</td>\n",
" <td>1.76</td>\n",
" <td>2.45</td>\n",
" <td>15.2</td>\n",
" <td>112.0</td>\n",
" <td>3.27</td>\n",
" <td>3.39</td>\n",
" <td>0.34</td>\n",
" <td>1.97</td>\n",
" <td>6.75</td>\n",
" <td>1.05</td>\n",
" <td>2.85</td>\n",
" <td>1450.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>14.39</td>\n",
" <td>1.87</td>\n",
" <td>2.45</td>\n",
" <td>14.6</td>\n",
" <td>96.0</td>\n",
" <td>2.5</td>\n",
" <td>2.52</td>\n",
" <td>0.3</td>\n",
" <td>1.98</td>\n",
" <td>5.25</td>\n",
" <td>1.02</td>\n",
" <td>3.58</td>\n",
" <td>1290.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>14.06</td>\n",
" <td>2.15</td>\n",
" <td>2.61</td>\n",
" <td>17.6</td>\n",
" <td>121.0</td>\n",
" <td>2.6</td>\n",
" <td>2.51</td>\n",
" <td>0.31</td>\n",
" <td>1.25</td>\n",
" <td>5.05</td>\n",
" <td>1.06</td>\n",
" <td>3.58</td>\n",
" <td>1295.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>14.83</td>\n",
" <td>1.64</td>\n",
" <td>2.17</td>\n",
" <td>14.0</td>\n",
" <td>97.0</td>\n",
" <td>2.8</td>\n",
" <td>2.98</td>\n",
" <td>0.29</td>\n",
" <td>1.98</td>\n",
" <td>5.2</td>\n",
" <td>1.08</td>\n",
" <td>2.85</td>\n",
" <td>1045.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>13.86</td>\n",
" <td>1.35</td>\n",
" <td>2.27</td>\n",
" <td>16.0</td>\n",
" <td>98.0</td>\n",
" <td>2.98</td>\n",
" <td>3.15</td>\n",
" <td>0.22</td>\n",
" <td>1.85</td>\n",
" <td>7.2199</td>\n",
" <td>1.01</td>\n",
" <td>3.55</td>\n",
" <td>1045.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0, 1),\n",
" (2, 13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0, 1),\n",
" (3, 13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0, 1),\n",
" (4, 14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0, 0),\n",
" (5, 13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0, 1),\n",
" (6, 14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0, 0),\n",
" (7, 14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0, 0),\n",
" (8, 14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0, 0),\n",
" (9, 14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0, 1),\n",
" (10, 13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0, 1)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_result_array;\n",
"\n",
"-- Run kmeans algorithm\n",
"CREATE TABLE km_result_array AS\n",
"SELECT * FROM madlib.kmeans_random('km_arrayin', -- Table of source data\n",
" 'ARRAY[p1, p2, p3, p4, p5, p6, -- Points\n",
" p7, p8, p9, p10, p11, p12, p13]', \n",
" 2, -- Number of centroids to calculate\n",
" 'madlib.squared_dist_norm2', -- Distance function\n",
" 'madlib.avg', -- Aggregate function\n",
" 20, -- Maximum number of iterations\n",
" 0.001); -- Minimum fraction of centroids reassigned to keep iterating \n",
"\n",
"-- Get point assignment\n",
"SELECT data.*, (madlib.closest_column(centroids, \n",
" ARRAY[p1, p2, p3, p4, p5, p6, \n",
" p7, p8, p9, p10, p11, p12, p13], \n",
" 'madlib.squared_dist_norm2')).column_id as cluster_id\n",
"FROM km_arrayin as data, km_result_array\n",
"ORDER BY data.pid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"range_k\"></a>\n",
"# Clustering for a range of k values\n",
"\n",
"# 1. Auto k selection\n",
"Now let's run k-means with random seeding for a range of k values and compare the results using the simplified silhouette and elbow methods."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>k</th>\n",
" <th>centroids</th>\n",
" <th>cluster_variance</th>\n",
" <th>objective_fn</th>\n",
" <th>frac_reassigned</th>\n",
" <th>num_iterations</th>\n",
" <th>silhouette</th>\n",
" <th>elbow</th>\n",
" <th>selection_algorithm</th>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]]</td>\n",
" <td>[0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633]</td>\n",
" <td>9384.14038935</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" <td>0.953601354123</td>\n",
" <td>44857.4267179</td>\n",
" <td>silhouette</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(5, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]], [0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633], 9384.14038934667, 0.0, 3, 0.953601354123412, 44857.4267178967, u'silhouette')]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_result_auto, km_result_auto_summary;\n",
"\n",
"SELECT madlib.kmeans_random_auto(\n",
" 'km_sample', -- points table\n",
" 'km_result_auto', -- output table\n",
" 'points', -- column name in point table\n",
" ARRAY[2,3,4,5,6], -- k values to try\n",
" 'madlib.squared_dist_norm2', -- distance function\n",
" 'madlib.avg', -- aggregate function\n",
" 20, -- max iterations\n",
" 0.001, -- minimum fraction of centroids reassigned to continue iterating\n",
" 'both' -- k selection algorithm (simple silhouette and elbow)\n",
");\n",
"\n",
"SELECT * FROM km_result_auto_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The best selection above is made by the silhouette algorithm by default. Note that the elbow method may select a different k value as the best. To see results for all k values:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>k</th>\n",
" <th>centroids</th>\n",
" <th>cluster_variance</th>\n",
" <th>objective_fn</th>\n",
" <th>frac_reassigned</th>\n",
" <th>num_iterations</th>\n",
" <th>silhouette</th>\n",
" <th>elbow</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[[13.7533333333333, 1.905, 2.425, 16.0666666666667, 90.3333333333333, 2.805, 2.98, 0.29, 2.005, 5.40663333333333, 1.04166666666667, 3.31833333333333, 1020.83333333333], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75]]</td>\n",
" <td>[122999.110416013, 30561.74805]</td>\n",
" <td>153560.858466</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" <td>0.86817460894</td>\n",
" <td>71506.2870379</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75], [13.856, 1.768, 2.336, 15.08, 84.8, 2.806, 3.038, 0.27, 2.042, 5.62396, 1.042, 3.396, 1078.0]]</td>\n",
" <td>[0.0, 30561.74805, 24007.669589612]</td>\n",
" <td>54569.4176396</td>\n",
" <td>0.0</td>\n",
" <td>4</td>\n",
" <td>0.914703320636</td>\n",
" <td>38199.4011006</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[[13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]]</td>\n",
" <td>[8078.22646267333, 0.0, 0.0, 90512.324426408]</td>\n",
" <td>98590.5508891</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" <td>0.894959666277</td>\n",
" <td>8221.52797196</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]]</td>\n",
" <td>[0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633]</td>\n",
" <td>9384.14038935</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" <td>0.953601354123</td>\n",
" <td>44857.4267179</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [14.04, 1.8225, 2.435, 16.65, 110.0, 2.845, 2.97, 0.295, 1.985, 5.594975, 1.0425, 3.3125, 972.5], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.225, 2.01, 2.53, 16.1, 108.5, 2.55, 2.515, 0.305, 1.615, 5.15, 1.04, 3.58, 1292.5], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]]</td>\n",
" <td>[0.0, 0.0, 76176.4564000075, 0.0, 329.8988, 0.0]</td>\n",
" <td>76506.3552</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" <td>0.772762876324</td>\n",
" <td>78164.3126552</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2, [[13.7533333333333, 1.905, 2.425, 16.0666666666667, 90.3333333333333, 2.805, 2.98, 0.29, 2.005, 5.40663333333333, 1.04166666666667, 3.31833333333333, 1020.83333333333], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75]], [122999.110416013, 30561.74805], 153560.858466013, 0.0, 2, 0.868174608939623, 71506.2870379353),\n",
" (3, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75], [13.856, 1.768, 2.336, 15.08, 84.8, 2.806, 3.038, 0.27, 2.042, 5.62396, 1.042, 3.396, 1078.0]], [0.0, 30561.74805, 24007.669589612], 54569.417639612, 0.0, 4, 0.914703320635945, 38199.4011006343),\n",
" (4, [[13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]], [8078.22646267333, 0.0, 0.0, 90512.324426408], 98590.5508890814, 0.0, 3, 0.894959666276704, 8221.52797196453),\n",
" (5, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]], [0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633], 9384.14038934667, 0.0, 3, 0.953601354123412, 44857.4267178967),\n",
" (6, [[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [14.04, 1.8225, 2.435, 16.65, 110.0, 2.845, 2.97, 0.295, 1.985, 5.594975, 1.0425, 3.3125, 972.5], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.225, 2.01, 2.53, 16.1, 108.5, 2.55, 2.515, 0.305, 1.615, 5.15, 1.04, 3.58, 1292.5], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]], [0.0, 0.0, 76176.4564000075, 0.0, 329.8988, 0.0], 76506.3552000075, 0.0, 2, 0.772762876324112, 78164.3126551977)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM km_result_auto ORDER BY k;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Simplified silhouette for each point for a specific k value\n",
"\n",
"Let's say we want the simplified silhouette coefficient for each point in the data set, for the case where k=3:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"10 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>pid</th>\n",
" <th>centroid_id</th>\n",
" <th>neighbor_centroid_id</th>\n",
" <th>silh</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0.980239585977</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0.930766555398</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0.688472504844</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0.93681403243</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0.963483713725</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0.820493559534</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0.852729167645</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0.987157842135</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0.986876244672</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 2, 1, 0.980239585976511),\n",
" (2, 2, 0, 0.930766555397587),\n",
" (3, 2, 1, 0.68847250484412),\n",
" (4, 1, 2, 0.936814032430309),\n",
" (5, 0, 2, 1.0),\n",
" (6, 1, 2, 0.963483713724561),\n",
" (7, 1, 2, 0.820493559534495),\n",
" (8, 1, 2, 0.852729167645154),\n",
" (9, 2, 0, 0.987157842134726),\n",
" (10, 2, 0, 0.98687624467199)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS km_points_silh;\n",
"\n",
"SELECT * FROM madlib.simple_silhouette_points( 'km_sample', -- Input points table\n",
" 'km_points_silh', -- Output table\n",
" 'pid', -- Point ID column in input table\n",
" 'points', -- Points column in input table\n",
" (SELECT centroids FROM km_result_auto WHERE k=3), -- centroids array\n",
" 'madlib.squared_dist_norm2' -- Distance function\n",
" );\n",
"\n",
"SELECT * FROM km_points_silh ORDER BY pid;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}