| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# K-means clustering\n", |
| "This module can compute clusters given the number of centroids k as an input, using a variety of seeding methods. It can also automatically select the best k value from a range of suggested k values, using the simplified silhouette method or the elbow method.\n", |
| "\n", |
| "## Table of contents\n", |
| "\n", |
| "<a href=\"#setup\">0. Setup</a>\n", |
| "\n", |
| "<a href=\"#single_k\">1. Clustering for single k value</a>\n", |
| "\n", |
| "<a href=\"#range_k\">2. Clustering for a range of k values</a>" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "<a id=\"setup\"></a>\n", |
| "# 0. Setup" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n", |
| " \"You should import from traitlets.config instead.\", ShimWarning)\n", |
| "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", |
| " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n" |
| ] |
| } |
| ], |
| "source": [ |
| "%load_ext sql" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "u'Connected: gpadmin@madlib'" |
| ] |
| }, |
| "execution_count": 2, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n", |
| "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n", |
| "\n", |
| "# Greenplum Database 5.x on GCP - via tunnel\n", |
| "%sql postgresql://gpadmin@localhost:8000/madlib\n", |
| " \n", |
| "# PostgreSQL local\n", |
| "#%sql postgresql://fmcquillan@localhost:5432/madlib" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>version</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-54-gec5614f, cmake configuration time: Wed Dec 18 17:08:05 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]" |
| ] |
| }, |
| "execution_count": 3, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%sql select madlib.version();\n", |
| "#%sql select version();" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "<a id=\"single_k\"></a>\n", |
| "# Clustering for single k value\n", |
| "\n", |
| "# 1. Input data" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "Done.\n", |
| "10 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>points</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>[13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>[14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>[14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>[13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0]</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]),\n", |
| " (2, [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]),\n", |
| " (3, [13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]),\n", |
| " (4, [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]),\n", |
| " (5, [13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]),\n", |
| " (6, [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]),\n", |
| " (7, [14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]),\n", |
| " (8, [14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]),\n", |
| " (9, [14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]),\n", |
| " (10, [13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0])]" |
| ] |
| }, |
| "execution_count": 4, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_sample;\n", |
| "\n", |
| "CREATE TABLE km_sample(pid int, points double precision[]);\n", |
| "\n", |
| "INSERT INTO km_sample VALUES\n", |
| "(1, '{14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065}'),\n", |
| "(2, '{13.2, 1.78, 2.14, 11.2, 1, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050}'),\n", |
| "(3, '{13.16, 2.36, 2.67, 18.6, 101, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185}'),\n", |
| "(4, '{14.37, 1.95, 2.5, 16.8, 113, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480}'),\n", |
| "(5, '{13.24, 2.59, 2.87, 21, 118, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735}'),\n", |
| "(6, '{14.2, 1.76, 2.45, 15.2, 112, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450}'),\n", |
| "(7, '{14.39, 1.87, 2.45, 14.6, 96, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290}'),\n", |
| "(8, '{14.06, 2.15, 2.61, 17.6, 121, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295}'),\n", |
| "(9, '{14.83, 1.64, 2.17, 14, 97, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045}'),\n", |
| "(10, '{13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, 3.55, 1045}');\n", |
| "\n", |
| "SELECT * FROM km_sample ORDER BY pid;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 2. Run k-means clustering using kmeans++ for centroid seeding\n", |
| "Use squared Euclidean distance, which is a commonly used distance function." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>centroids</th>\n", |
| " <th>cluster_variance</th>\n", |
| " <th>objective_fn</th>\n", |
| " <th>frac_reassigned</th>\n", |
| " <th>num_iterations</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>[[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]]</td>\n", |
| " <td>[90512.324426408, 60672.638245208]</td>\n", |
| " <td>151184.962672</td>\n", |
| " <td>0.0</td>\n", |
| " <td>2</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[([[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]], [90512.324426408, 60672.638245208], 151184.962671616, 0.0, 2)]" |
| ] |
| }, |
| "execution_count": 5, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_result;\n", |
| "\n", |
| "CREATE TABLE km_result AS\n", |
| "SELECT * FROM madlib.kmeanspp( 'km_sample', -- Table of source data\n", |
| " 'points', -- Column containing point co-ordinates \n", |
| " 2, -- Number of centroids to calculate\n", |
| " 'madlib.squared_dist_norm2', -- Distance function\n", |
| " 'madlib.avg', -- Aggregate function\n", |
| " 20, -- Number of iterations\n", |
| " 0.001 -- Fraction of centroids reassigned to keep iterating \n", |
| " );\n", |
| "\n", |
| "SELECT * FROM km_result;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 3. Simplified silhouette coefficient\n", |
| "Compute the average simplified silhouette coefficient over the whole data set. Make sure to use the same distance function as in the k-means run above." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>simple_silhouette</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>0.872087020147</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(0.872087020146542,)]" |
| ] |
| }, |
| "execution_count": 6, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT * FROM madlib.simple_silhouette( 'km_sample', -- Input points table\n", |
| " 'points', -- Column containing points\n", |
| " (SELECT centroids FROM km_result), -- Centroids\n", |
| " 'madlib.squared_dist_norm2' -- Distance function\n", |
| " );" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Now calculate the simplified silhouette coefficient for each point in the data set:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>centroid_id</th>\n", |
| " <th>neighbor_centroid_id</th>\n", |
| " <th>silh</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>1</td>\n", |
| " <td>0.902123603766</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>0</td>\n", |
| " <td>1</td>\n", |
| " <td>0.88017393665</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>0.382089480836</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>0.919141622264</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>0</td>\n", |
| " <td>1</td>\n", |
| " <td>0.822664572979</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>0.943394365443</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>0.970809939945</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>1</td>\n", |
| " <td>0</td>\n", |
| " <td>0.977109993192</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>0</td>\n", |
| " <td>1</td>\n", |
| " <td>0.961796151461</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>0</td>\n", |
| " <td>1</td>\n", |
| " <td>0.961566534928</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 0, 1, 0.902123603765829),\n", |
| " (2, 0, 1, 0.880173936650081),\n", |
| " (3, 1, 0, 0.382089480836045),\n", |
| " (4, 1, 0, 0.919141622264229),\n", |
| " (5, 0, 1, 0.822664572979012),\n", |
| " (6, 1, 0, 0.943394365442933),\n", |
| " (7, 1, 0, 0.97080993994542),\n", |
| " (8, 1, 0, 0.977109993192101),\n", |
| " (9, 0, 1, 0.961796151461416),\n", |
| " (10, 0, 1, 0.961566534928355)]" |
| ] |
| }, |
| "execution_count": 7, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_points_silh;\n", |
| "\n", |
| "SELECT * FROM madlib.simple_silhouette_points( 'km_sample', -- Input points table\n", |
| " 'km_points_silh', -- Output table\n", |
| " 'pid', -- Point ID column in input table\n", |
| " 'points', -- Points column in input table\n", |
| " 'km_result', -- Centroids table\n", |
| " 'centroids', -- Column in centroids table containing centroids \n", |
| " 'madlib.squared_dist_norm2' -- Distance function\n", |
| " );\n", |
| "\n", |
| "SELECT * FROM km_points_silh ORDER BY pid;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 4. Cluster assignment for each point\n", |
| "Use the closest_column() function to map each point to the cluster that it belongs to. " |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "10 rows affected.\n", |
| "Done.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>points</th>\n", |
| " <th>cluster_id</th>\n", |
| " <th>distance</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>[14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]</td>\n", |
| " <td>0</td>\n", |
| " <td>7435.21009152</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>[13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]</td>\n", |
| " <td>0</td>\n", |
| " <td>11468.6944211</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0]</td>\n", |
| " <td>1</td>\n", |
| " <td>24088.7121562</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>[14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0]</td>\n", |
| " <td>1</td>\n", |
| " <td>19623.3917358</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]</td>\n", |
| " <td>0</td>\n", |
| " <td>64929.1866187</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>[14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0]</td>\n", |
| " <td>1</td>\n", |
| " <td>12114.3606538</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>[14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0]</td>\n", |
| " <td>1</td>\n", |
| " <td>2664.07117376</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>[14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0]</td>\n", |
| " <td>1</td>\n", |
| " <td>2182.10252576</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>[14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0]</td>\n", |
| " <td>0</td>\n", |
| " <td>3330.15771392</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>[13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0]</td>\n", |
| " <td>0</td>\n", |
| " <td>3349.07558113</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0], 0, 7435.2100915204),\n", |
| " (2, [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], 0, 11468.6944211204),\n", |
| " (3, [13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], 1, 24088.7121561664),\n", |
| " (4, [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], 1, 19623.3917357604),\n", |
| " (5, [13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], 0, 64929.1866187204),\n", |
| " (6, [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], 1, 12114.3606537604),\n", |
| " (7, [14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0], 1, 2664.0711737604),\n", |
| " (8, [14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0], 1, 2182.1025257604),\n", |
| " (9, [14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0], 0, 3330.1577139204),\n", |
| " (10, [13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0], 0, 3349.0755811264)]" |
| ] |
| }, |
| "execution_count": 8, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS point_cluster_map;\n", |
| "\n", |
| "CREATE TABLE point_cluster_map AS \n", |
| "SELECT data.*, (madlib.closest_column(centroids, points, 'madlib.squared_dist_norm2')).*\n", |
| "FROM km_sample as data, km_result;\n", |
| "\n", |
| "ALTER TABLE point_cluster_map RENAME column_id to cluster_id; -- change column name\n", |
| "SELECT * FROM point_cluster_map ORDER BY pid;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 5. Display cluster ID\n", |
| "If you need the cluster ID associated with each set of centroid coordinates, there are two steps. First, unnest the 2-D array of cluster centroids to get a set of 1-D centroid arrays:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "2 rows affected.\n", |
| "2 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>unnest_row_id</th>\n", |
| " <th>unnest_result</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>[14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]),\n", |
| " (2, [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0])]" |
| ] |
| }, |
| "execution_count": 9, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_centroids_unnest;\n", |
| "\n", |
| "-- Run unnest function\n", |
| "CREATE TABLE km_centroids_unnest AS\n", |
| "SELECT (madlib.array_unnest_2d_to_1d(centroids)).*\n", |
| "FROM km_result;\n", |
| "\n", |
| "SELECT * FROM km_centroids_unnest ORDER BY 1;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Note that the ID column returned by 'array_unnest_2d_to_1d()' is just a row ID and not the cluster ID assigned by k-means. The second step is to look up the cluster ID of each unnested centroid using closest_column():" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "2 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>unnest_row_id</th>\n", |
| " <th>unnest_result</th>\n", |
| " <th>cluster_id</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>[13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>[14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0]</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0], 0),\n", |
| " (2, [14.036, 2.018, 2.536, 16.56, 108.6, 3.004, 3.03, 0.298, 2.038, 6.10598, 1.004, 3.326, 1340.0], 1)]" |
| ] |
| }, |
| "execution_count": 10, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT cent.*, (madlib.closest_column(centroids, unnest_result, 'madlib.squared_dist_norm2')).column_id as cluster_id\n", |
| "FROM km_centroids_unnest as cent, km_result\n", |
| "ORDER BY cent.unnest_row_id;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 6. Array input\n", |
| "Create the input table:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "Done.\n", |
| "10 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>p1</th>\n", |
| " <th>p2</th>\n", |
| " <th>p3</th>\n", |
| " <th>p4</th>\n", |
| " <th>p5</th>\n", |
| " <th>p6</th>\n", |
| " <th>p7</th>\n", |
| " <th>p8</th>\n", |
| " <th>p9</th>\n", |
| " <th>p10</th>\n", |
| " <th>p11</th>\n", |
| " <th>p12</th>\n", |
| " <th>p13</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>14.23</td>\n", |
| " <td>1.71</td>\n", |
| " <td>2.43</td>\n", |
| " <td>15.6</td>\n", |
| " <td>127.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>3.06</td>\n", |
| " <td>0.28</td>\n", |
| " <td>2.29</td>\n", |
| " <td>5.64</td>\n", |
| " <td>1.04</td>\n", |
| " <td>3.92</td>\n", |
| " <td>1065.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>13.2</td>\n", |
| " <td>1.78</td>\n", |
| " <td>2.14</td>\n", |
| " <td>11.2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>2.65</td>\n", |
| " <td>2.76</td>\n", |
| " <td>0.26</td>\n", |
| " <td>1.28</td>\n", |
| " <td>4.38</td>\n", |
| " <td>1.05</td>\n", |
| " <td>3.49</td>\n", |
| " <td>1050.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>13.16</td>\n", |
| " <td>2.36</td>\n", |
| " <td>2.67</td>\n", |
| " <td>18.6</td>\n", |
| " <td>101.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>3.24</td>\n", |
| " <td>0.3</td>\n", |
| " <td>2.81</td>\n", |
| " <td>5.6799</td>\n", |
| " <td>1.03</td>\n", |
| " <td>3.17</td>\n", |
| " <td>1185.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>14.37</td>\n", |
| " <td>1.95</td>\n", |
| " <td>2.5</td>\n", |
| " <td>16.8</td>\n", |
| " <td>113.0</td>\n", |
| " <td>3.85</td>\n", |
| " <td>3.49</td>\n", |
| " <td>0.24</td>\n", |
| " <td>2.18</td>\n", |
| " <td>7.8</td>\n", |
| " <td>0.86</td>\n", |
| " <td>3.45</td>\n", |
| " <td>1480.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>13.24</td>\n", |
| " <td>2.59</td>\n", |
| " <td>2.87</td>\n", |
| " <td>21.0</td>\n", |
| " <td>118.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>2.69</td>\n", |
| " <td>0.39</td>\n", |
| " <td>1.82</td>\n", |
| " <td>4.32</td>\n", |
| " <td>1.04</td>\n", |
| " <td>2.93</td>\n", |
| " <td>735.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>14.2</td>\n", |
| " <td>1.76</td>\n", |
| " <td>2.45</td>\n", |
| " <td>15.2</td>\n", |
| " <td>112.0</td>\n", |
| " <td>3.27</td>\n", |
| " <td>3.39</td>\n", |
| " <td>0.34</td>\n", |
| " <td>1.97</td>\n", |
| " <td>6.75</td>\n", |
| " <td>1.05</td>\n", |
| " <td>2.85</td>\n", |
| " <td>1450.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>14.39</td>\n", |
| " <td>1.87</td>\n", |
| " <td>2.45</td>\n", |
| " <td>14.6</td>\n", |
| " <td>96.0</td>\n", |
| " <td>2.5</td>\n", |
| " <td>2.52</td>\n", |
| " <td>0.3</td>\n", |
| " <td>1.98</td>\n", |
| " <td>5.25</td>\n", |
| " <td>1.02</td>\n", |
| " <td>3.58</td>\n", |
| " <td>1290.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>14.06</td>\n", |
| " <td>2.15</td>\n", |
| " <td>2.61</td>\n", |
| " <td>17.6</td>\n", |
| " <td>121.0</td>\n", |
| " <td>2.6</td>\n", |
| " <td>2.51</td>\n", |
| " <td>0.31</td>\n", |
| " <td>1.25</td>\n", |
| " <td>5.05</td>\n", |
| " <td>1.06</td>\n", |
| " <td>3.58</td>\n", |
| " <td>1295.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>14.83</td>\n", |
| " <td>1.64</td>\n", |
| " <td>2.17</td>\n", |
| " <td>14.0</td>\n", |
| " <td>97.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>2.98</td>\n", |
| " <td>0.29</td>\n", |
| " <td>1.98</td>\n", |
| " <td>5.2</td>\n", |
| " <td>1.08</td>\n", |
| " <td>2.85</td>\n", |
| " <td>1045.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>13.86</td>\n", |
| " <td>1.35</td>\n", |
| " <td>2.27</td>\n", |
| " <td>16.0</td>\n", |
| " <td>98.0</td>\n", |
| " <td>2.98</td>\n", |
| " <td>3.15</td>\n", |
| " <td>0.22</td>\n", |
| " <td>1.85</td>\n", |
| " <td>7.2199</td>\n", |
| " <td>1.01</td>\n", |
| " <td>3.55</td>\n", |
| " <td>1045.0</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0),\n", |
| " (2, 13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0),\n", |
| " (3, 13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0),\n", |
| " (4, 14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0),\n", |
| " (5, 13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0),\n", |
| " (6, 14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0),\n", |
| " (7, 14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0),\n", |
| " (8, 14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0),\n", |
| " (9, 14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0),\n", |
| " (10, 13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0)]" |
| ] |
| }, |
| "execution_count": 11, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_arrayin CASCADE;\n", |
| "\n", |
| "CREATE TABLE km_arrayin(pid int, \n", |
| " p1 float, \n", |
| " p2 float, \n", |
| " p3 float,\n", |
| " p4 float, \n", |
| " p5 float, \n", |
| " p6 float,\n", |
| " p7 float, \n", |
| " p8 float, \n", |
| " p9 float,\n", |
| " p10 float, \n", |
| " p11 float, \n", |
| " p12 float,\n", |
| " p13 float);\n", |
| "\n", |
| "INSERT INTO km_arrayin VALUES\n", |
| "(1, 14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065),\n", |
| "(2, 13.2, 1.78, 2.14, 11.2, 1, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050),\n", |
| "(3, 13.16, 2.36, 2.67, 18.6, 101, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185),\n", |
| "(4, 14.37, 1.95, 2.5, 16.8, 113, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480),\n", |
| "(5, 13.24, 2.59, 2.87, 21, 118, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735),\n", |
| "(6, 14.2, 1.76, 2.45, 15.2, 112, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450),\n", |
| "(7, 14.39, 1.87, 2.45, 14.6, 96, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290),\n", |
| "(8, 14.06, 2.15, 2.61, 17.6, 121, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295),\n", |
| "(9, 14.83, 1.64, 2.17, 14, 97, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045),\n", |
| "(10, 13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, 3.55, 1045);\n", |
| "\n", |
| "SELECT * FROM km_arrayin ORDER BY pid;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Now find the cluster assignment for each point using random seeding:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "metadata": { |
| "scrolled": true |
| }, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>p1</th>\n", |
| " <th>p2</th>\n", |
| " <th>p3</th>\n", |
| " <th>p4</th>\n", |
| " <th>p5</th>\n", |
| " <th>p6</th>\n", |
| " <th>p7</th>\n", |
| " <th>p8</th>\n", |
| " <th>p9</th>\n", |
| " <th>p10</th>\n", |
| " <th>p11</th>\n", |
| " <th>p12</th>\n", |
| " <th>p13</th>\n", |
| " <th>cluster_id</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>14.23</td>\n", |
| " <td>1.71</td>\n", |
| " <td>2.43</td>\n", |
| " <td>15.6</td>\n", |
| " <td>127.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>3.06</td>\n", |
| " <td>0.28</td>\n", |
| " <td>2.29</td>\n", |
| " <td>5.64</td>\n", |
| " <td>1.04</td>\n", |
| " <td>3.92</td>\n", |
| " <td>1065.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>13.2</td>\n", |
| " <td>1.78</td>\n", |
| " <td>2.14</td>\n", |
| " <td>11.2</td>\n", |
| " <td>1.0</td>\n", |
| " <td>2.65</td>\n", |
| " <td>2.76</td>\n", |
| " <td>0.26</td>\n", |
| " <td>1.28</td>\n", |
| " <td>4.38</td>\n", |
| " <td>1.05</td>\n", |
| " <td>3.49</td>\n", |
| " <td>1050.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>13.16</td>\n", |
| " <td>2.36</td>\n", |
| " <td>2.67</td>\n", |
| " <td>18.6</td>\n", |
| " <td>101.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>3.24</td>\n", |
| " <td>0.3</td>\n", |
| " <td>2.81</td>\n", |
| " <td>5.6799</td>\n", |
| " <td>1.03</td>\n", |
| " <td>3.17</td>\n", |
| " <td>1185.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>14.37</td>\n", |
| " <td>1.95</td>\n", |
| " <td>2.5</td>\n", |
| " <td>16.8</td>\n", |
| " <td>113.0</td>\n", |
| " <td>3.85</td>\n", |
| " <td>3.49</td>\n", |
| " <td>0.24</td>\n", |
| " <td>2.18</td>\n", |
| " <td>7.8</td>\n", |
| " <td>0.86</td>\n", |
| " <td>3.45</td>\n", |
| " <td>1480.0</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>13.24</td>\n", |
| " <td>2.59</td>\n", |
| " <td>2.87</td>\n", |
| " <td>21.0</td>\n", |
| " <td>118.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>2.69</td>\n", |
| " <td>0.39</td>\n", |
| " <td>1.82</td>\n", |
| " <td>4.32</td>\n", |
| " <td>1.04</td>\n", |
| " <td>2.93</td>\n", |
| " <td>735.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>14.2</td>\n", |
| " <td>1.76</td>\n", |
| " <td>2.45</td>\n", |
| " <td>15.2</td>\n", |
| " <td>112.0</td>\n", |
| " <td>3.27</td>\n", |
| " <td>3.39</td>\n", |
| " <td>0.34</td>\n", |
| " <td>1.97</td>\n", |
| " <td>6.75</td>\n", |
| " <td>1.05</td>\n", |
| " <td>2.85</td>\n", |
| " <td>1450.0</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>14.39</td>\n", |
| " <td>1.87</td>\n", |
| " <td>2.45</td>\n", |
| " <td>14.6</td>\n", |
| " <td>96.0</td>\n", |
| " <td>2.5</td>\n", |
| " <td>2.52</td>\n", |
| " <td>0.3</td>\n", |
| " <td>1.98</td>\n", |
| " <td>5.25</td>\n", |
| " <td>1.02</td>\n", |
| " <td>3.58</td>\n", |
| " <td>1290.0</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>14.06</td>\n", |
| " <td>2.15</td>\n", |
| " <td>2.61</td>\n", |
| " <td>17.6</td>\n", |
| " <td>121.0</td>\n", |
| " <td>2.6</td>\n", |
| " <td>2.51</td>\n", |
| " <td>0.31</td>\n", |
| " <td>1.25</td>\n", |
| " <td>5.05</td>\n", |
| " <td>1.06</td>\n", |
| " <td>3.58</td>\n", |
| " <td>1295.0</td>\n", |
| " <td>0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>14.83</td>\n", |
| " <td>1.64</td>\n", |
| " <td>2.17</td>\n", |
| " <td>14.0</td>\n", |
| " <td>97.0</td>\n", |
| " <td>2.8</td>\n", |
| " <td>2.98</td>\n", |
| " <td>0.29</td>\n", |
| " <td>1.98</td>\n", |
| " <td>5.2</td>\n", |
| " <td>1.08</td>\n", |
| " <td>2.85</td>\n", |
| " <td>1045.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>13.86</td>\n", |
| " <td>1.35</td>\n", |
| " <td>2.27</td>\n", |
| " <td>16.0</td>\n", |
| " <td>98.0</td>\n", |
| " <td>2.98</td>\n", |
| " <td>3.15</td>\n", |
| " <td>0.22</td>\n", |
| " <td>1.85</td>\n", |
| " <td>7.2199</td>\n", |
| " <td>1.01</td>\n", |
| " <td>3.55</td>\n", |
| " <td>1045.0</td>\n", |
| " <td>1</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0, 1),\n", |
| " (2, 13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0, 1),\n", |
| " (3, 13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0, 1),\n", |
| " (4, 14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0, 0),\n", |
| " (5, 13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0, 1),\n", |
| " (6, 14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0, 0),\n", |
| " (7, 14.39, 1.87, 2.45, 14.6, 96.0, 2.5, 2.52, 0.3, 1.98, 5.25, 1.02, 3.58, 1290.0, 0),\n", |
| " (8, 14.06, 2.15, 2.61, 17.6, 121.0, 2.6, 2.51, 0.31, 1.25, 5.05, 1.06, 3.58, 1295.0, 0),\n", |
| " (9, 14.83, 1.64, 2.17, 14.0, 97.0, 2.8, 2.98, 0.29, 1.98, 5.2, 1.08, 2.85, 1045.0, 1),\n", |
| " (10, 13.86, 1.35, 2.27, 16.0, 98.0, 2.98, 3.15, 0.22, 1.85, 7.2199, 1.01, 3.55, 1045.0, 1)]" |
| ] |
| }, |
| "execution_count": 12, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_result_array;\n", |
| "\n", |
| "-- Run kmeans algorithm with random seeding and store its result\n", |
| "CREATE TABLE km_result_array AS\n", |
| "SELECT * FROM madlib.kmeans_random('km_arrayin', -- Table of source data\n", |
| " 'ARRAY[p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13]', -- Points, an array expression built from columns p1 to p13\n", |
| " 2, -- Number of centroids to calculate\n", |
| " 'madlib.squared_dist_norm2', -- Distance function\n", |
| " 'madlib.avg', -- Aggregate function\n", |
| " 20, -- Maximum number of iterations\n", |
| " 0.001); -- Minimum fraction of centroids reassigned to continue iterating\n", |
| "\n", |
| "-- Get point assignment by labeling each point with the id of its closest centroid\n", |
| "-- The unqualified join assumes km_result_array holds a single row of centroids\n", |
| "SELECT data.*, (madlib.closest_column(centroids, \n", |
| " ARRAY[p1, p2, p3, p4, p5, p6, \n", |
| " p7, p8, p9, p10, p11, p12, p13], \n", |
| " 'madlib.squared_dist_norm2')).column_id as cluster_id\n", |
| "FROM km_arrayin as data, km_result_array\n", |
| "ORDER BY data.pid;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "<a id=\"range_k\"></a>\n", |
| "# Clustering for a range of k values\n", |
| "\n", |
| "# 1. Auto k selection\n", |
| "Now let's run k-means with random seeding for a range of k values and compare the results using the simplified silhouette and elbow methods." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "1 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>k</th>\n", |
| " <th>centroids</th>\n", |
| " <th>cluster_variance</th>\n", |
| " <th>objective_fn</th>\n", |
| " <th>frac_reassigned</th>\n", |
| " <th>num_iterations</th>\n", |
| " <th>silhouette</th>\n", |
| " <th>elbow</th>\n", |
| " <th>selection_algorithm</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]]</td>\n", |
| " <td>[0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633]</td>\n", |
| " <td>9384.14038935</td>\n", |
| " <td>0.0</td>\n", |
| " <td>3</td>\n", |
| " <td>0.953601354123</td>\n", |
| " <td>44857.4267179</td>\n", |
| " <td>silhouette</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(5, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]], [0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633], 9384.14038934667, 0.0, 3, 0.953601354123412, 44857.4267178967, u'silhouette')]" |
| ] |
| }, |
| "execution_count": 13, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_result_auto, km_result_auto_summary;\n", |
| "\n", |
| "-- Cluster once per candidate k and score every run\n", |
| "SELECT madlib.kmeans_random_auto(\n", |
| " 'km_sample', -- input table with the points\n", |
| " 'km_result_auto', -- output table to create\n", |
| " 'points', -- points column in the input table\n", |
| " ARRAY[2, 3, 4, 5, 6], -- candidate k values\n", |
| " 'madlib.squared_dist_norm2', -- distance function\n", |
| " 'madlib.avg', -- aggregate function\n", |
| " 20, -- maximum number of iterations\n", |
| " 0.001, -- minimum fraction of centroids reassigned to continue iterating\n", |
| " 'both' -- k selection algorithm, here both simplified silhouette and elbow\n", |
| ");\n", |
| "\n", |
| "-- Show the selected k together with its metrics\n", |
| "SELECT *\n", |
| "FROM km_result_auto_summary;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The summary table above reports the best k, which is selected by the silhouette algorithm by default. Note that the elbow method may select a different k value as the best. To see results for all k values:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "5 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>k</th>\n", |
| " <th>centroids</th>\n", |
| " <th>cluster_variance</th>\n", |
| " <th>objective_fn</th>\n", |
| " <th>frac_reassigned</th>\n", |
| " <th>num_iterations</th>\n", |
| " <th>silhouette</th>\n", |
| " <th>elbow</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>[[13.7533333333333, 1.905, 2.425, 16.0666666666667, 90.3333333333333, 2.805, 2.98, 0.29, 2.005, 5.40663333333333, 1.04166666666667, 3.31833333333333, 1020.83333333333], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75]]</td>\n", |
| " <td>[122999.110416013, 30561.74805]</td>\n", |
| " <td>153560.858466</td>\n", |
| " <td>0.0</td>\n", |
| " <td>2</td>\n", |
| " <td>0.86817460894</td>\n", |
| " <td>71506.2870379</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75], [13.856, 1.768, 2.336, 15.08, 84.8, 2.806, 3.038, 0.27, 2.042, 5.62396, 1.042, 3.396, 1078.0]]</td>\n", |
| " <td>[0.0, 30561.74805, 24007.669589612]</td>\n", |
| " <td>54569.4176396</td>\n", |
| " <td>0.0</td>\n", |
| " <td>4</td>\n", |
| " <td>0.914703320636</td>\n", |
| " <td>38199.4011006</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>[[13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]]</td>\n", |
| " <td>[8078.22646267333, 0.0, 0.0, 90512.324426408]</td>\n", |
| " <td>98590.5508891</td>\n", |
| " <td>0.0</td>\n", |
| " <td>3</td>\n", |
| " <td>0.894959666277</td>\n", |
| " <td>8221.52797196</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>[[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]]</td>\n", |
| " <td>[0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633]</td>\n", |
| " <td>9384.14038935</td>\n", |
| " <td>0.0</td>\n", |
| " <td>3</td>\n", |
| " <td>0.953601354123</td>\n", |
| " <td>44857.4267179</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>[[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [14.04, 1.8225, 2.435, 16.65, 110.0, 2.845, 2.97, 0.295, 1.985, 5.594975, 1.0425, 3.3125, 972.5], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.225, 2.01, 2.53, 16.1, 108.5, 2.55, 2.515, 0.305, 1.615, 5.15, 1.04, 3.58, 1292.5], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]]</td>\n", |
| " <td>[0.0, 0.0, 76176.4564000075, 0.0, 329.8988, 0.0]</td>\n", |
| " <td>76506.3552</td>\n", |
| " <td>0.0</td>\n", |
| " <td>2</td>\n", |
| " <td>0.772762876324</td>\n", |
| " <td>78164.3126552</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(2, [[13.7533333333333, 1.905, 2.425, 16.0666666666667, 90.3333333333333, 2.805, 2.98, 0.29, 2.005, 5.40663333333333, 1.04166666666667, 3.31833333333333, 1020.83333333333], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75]], [122999.110416013, 30561.74805], 153560.858466013, 0.0, 2, 0.868174608939623, 71506.2870379353),\n", |
| " (3, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.255, 1.9325, 2.5025, 16.05, 110.5, 3.055, 2.9775, 0.2975, 1.845, 6.2125, 0.9975, 3.365, 1378.75], [13.856, 1.768, 2.336, 15.08, 84.8, 2.806, 3.038, 0.27, 2.042, 5.62396, 1.042, 3.396, 1078.0]], [0.0, 30561.74805, 24007.669589612], 54569.417639612, 0.0, 4, 0.914703320635945, 38199.4011006343),\n", |
| " (4, [[13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [13.872, 1.814, 2.376, 15.56, 88.2, 2.806, 2.928, 0.288, 1.844, 5.35198, 1.044, 3.348, 988.0]], [8078.22646267333, 0.0, 0.0, 90512.324426408], 98590.5508890814, 0.0, 3, 0.894959666276704, 8221.52797196453),\n", |
| " (5, [[13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0], [14.3066666666667, 1.56666666666667, 2.29, 15.2, 107.333333333333, 2.86, 3.06333333333333, 0.263333333333333, 2.04, 6.01996666666667, 1.04333333333333, 3.44, 1051.66666666667], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0], [13.87, 2.12666666666667, 2.57666666666667, 16.9333333333333, 106.0, 2.63333333333333, 2.75666666666667, 0.303333333333333, 2.01333333333333, 5.32663333333333, 1.03666666666667, 3.44333333333333, 1256.66666666667], [14.285, 1.855, 2.475, 16.0, 112.5, 3.56, 3.44, 0.29, 2.075, 7.275, 0.955, 3.15, 1465.0]], [0.0, 853.150626673333, 0.0, 8078.22646267333, 452.7633], 9384.14038934667, 0.0, 3, 0.953601354123412, 44857.4267178967),\n", |
| " (6, [[13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.6799, 1.03, 3.17, 1185.0], [14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0], [14.04, 1.8225, 2.435, 16.65, 110.0, 2.845, 2.97, 0.295, 1.985, 5.594975, 1.0425, 3.3125, 972.5], [14.2, 1.76, 2.45, 15.2, 112.0, 3.27, 3.39, 0.34, 1.97, 6.75, 1.05, 2.85, 1450.0], [14.225, 2.01, 2.53, 16.1, 108.5, 2.55, 2.515, 0.305, 1.615, 5.15, 1.04, 3.58, 1292.5], [13.2, 1.78, 2.14, 11.2, 1.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.49, 1050.0]], [0.0, 0.0, 76176.4564000075, 0.0, 329.8988, 0.0], 76506.3552000075, 0.0, 2, 0.772762876324112, 78164.3126551977)]" |
| ] |
| }, |
| "execution_count": 14, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "SELECT *\n", |
| "FROM km_result_auto -- one row per k value that was tried\n", |
| "ORDER BY k;" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "# 2. Simplified silhouette for each point for a specific k value\n", |
| "\n", |
| "Let's say we want the simplified silhouette coefficient for each point in the data set, for the case where k=3:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "Done.\n", |
| "1 rows affected.\n", |
| "10 rows affected.\n" |
| ] |
| }, |
| { |
| "data": { |
| "text/html": [ |
| "<table>\n", |
| " <tr>\n", |
| " <th>pid</th>\n", |
| " <th>centroid_id</th>\n", |
| " <th>neighbor_centroid_id</th>\n", |
| " <th>silh</th>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>1</td>\n", |
| " <td>2</td>\n", |
| " <td>1</td>\n", |
| " <td>0.980239585977</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>2</td>\n", |
| " <td>2</td>\n", |
| " <td>0</td>\n", |
| " <td>0.930766555398</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>3</td>\n", |
| " <td>2</td>\n", |
| " <td>1</td>\n", |
| " <td>0.688472504844</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>4</td>\n", |
| " <td>1</td>\n", |
| " <td>2</td>\n", |
| " <td>0.93681403243</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>5</td>\n", |
| " <td>0</td>\n", |
| " <td>2</td>\n", |
| " <td>1.0</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>6</td>\n", |
| " <td>1</td>\n", |
| " <td>2</td>\n", |
| " <td>0.963483713725</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>7</td>\n", |
| " <td>1</td>\n", |
| " <td>2</td>\n", |
| " <td>0.820493559534</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>8</td>\n", |
| " <td>1</td>\n", |
| " <td>2</td>\n", |
| " <td>0.852729167645</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>9</td>\n", |
| " <td>2</td>\n", |
| " <td>0</td>\n", |
| " <td>0.987157842135</td>\n", |
| " </tr>\n", |
| " <tr>\n", |
| " <td>10</td>\n", |
| " <td>2</td>\n", |
| " <td>0</td>\n", |
| " <td>0.986876244672</td>\n", |
| " </tr>\n", |
| "</table>" |
| ], |
| "text/plain": [ |
| "[(1, 2, 1, 0.980239585976511),\n", |
| " (2, 2, 0, 0.930766555397587),\n", |
| " (3, 2, 1, 0.68847250484412),\n", |
| " (4, 1, 2, 0.936814032430309),\n", |
| " (5, 0, 2, 1.0),\n", |
| " (6, 1, 2, 0.963483713724561),\n", |
| " (7, 1, 2, 0.820493559534495),\n", |
| " (8, 1, 2, 0.852729167645154),\n", |
| " (9, 2, 0, 0.987157842134726),\n", |
| " (10, 2, 0, 0.98687624467199)]" |
| ] |
| }, |
| "execution_count": 15, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "%%sql\n", |
| "DROP TABLE IF EXISTS km_points_silh;\n", |
| "\n", |
| "-- Score every point against the centroids found for k=3\n", |
| "SELECT *\n", |
| "FROM madlib.simple_silhouette_points(\n", |
| " 'km_sample', -- table containing the points\n", |
| " 'km_points_silh', -- table that receives the per-point results\n", |
| " 'pid', -- point ID column of the input table\n", |
| " 'points', -- points column of the input table\n", |
| " (SELECT centroids FROM km_result_auto WHERE k = 3), -- centroids array taken from the k=3 row\n", |
| " 'madlib.squared_dist_norm2' -- distance function\n", |
| ");\n", |
| "\n", |
| "-- List the per-point coefficients in point ID order\n", |
| "SELECT *\n", |
| "FROM km_points_silh\n", |
| "ORDER BY pid;" |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 2", |
| "language": "python", |
| "name": "python2" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 2 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython2", |
| "version": "2.7.10" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 1 |
| } |