blob: 2f3d51a180a2fc316cbdbeb5f1399a3beeaaa997 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# k-Nearest Neighbors\n",
"Finds k nearest data points to a given data point and outputs majority vote value of output classes in case of classification, and average value of target values in case of regression. KNN was first added in MADlib 1.10 with multiple updates in subsequent releases."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
" \"You should import from traitlets.config instead.\", ShimWarning)\n",
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: gpadmin@madlib'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP (demo machine)\n",
"%sql postgresql://gpadmin@35.184.232.200:5432/madlib\n",
" \n",
"# PostgreSQL local\n",
"#%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.16-dev, git revision: rc/1.15.1-rc1-29-g0ba6155, cmake configuration time: Wed Feb 20 17:40:16 UTC 2019, build type: release, build system: Linux-2.6.32-754.6.3.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.16-dev, git revision: rc/1.15.1-rc1-29-g0ba6155, cmake configuration time: Wed Feb 20 17:40:16 UTC 2019, build type: release, build system: Linux-2.6.32-754.6.3.el6.x86_64, C compiler: gcc 4.4.7, C++ compiler: g++ 4.4.7',)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Load data for classification"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"9 rows affected.\n",
"9 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>label</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[1, 1]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 2]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[3, 3]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[4, 4]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[4, 5]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[20, 50]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[10, 31]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[81, 13]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[1, 111]</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [1, 1], 1),\n",
" (2, [2, 2], 1),\n",
" (3, [3, 3], 1),\n",
" (4, [4, 4], 1),\n",
" (5, [4, 5], 1),\n",
" (6, [20, 50], 0),\n",
" (7, [10, 31], 0),\n",
" (8, [81, 13], 0),\n",
" (9, [1, 111], 0)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS knn_train_data;\n",
"\n",
"CREATE TABLE knn_train_data (\n",
" id integer, \n",
" data integer[], \n",
" label integer -- Integer label means for classification\n",
" );\n",
"\n",
"INSERT INTO knn_train_data VALUES\n",
"(1, '{1,1}', 1),\n",
"(2, '{2,2}', 1),\n",
"(3, '{3,3}', 1),\n",
"(4, '{4,4}', 1),\n",
"(5, '{4,5}', 1),\n",
"(6, '{20,50}', 0),\n",
"(7, '{10,31}', 0),\n",
"(8, '{81,13}', 0),\n",
"(9, '{1,111}', 0);\n",
"\n",
"SELECT * FROM knn_train_data ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Load data for regression"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"9 rows affected.\n",
"9 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>label</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[1, 1]</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 2]</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[3, 3]</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[4, 4]</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[4, 5]</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[20, 50]</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>[10, 31]</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>[81, 13]</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>[1, 111]</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [1, 1], 1.0),\n",
" (2, [2, 2], 1.0),\n",
" (3, [3, 3], 1.0),\n",
" (4, [4, 4], 1.0),\n",
" (5, [4, 5], 1.0),\n",
" (6, [20, 50], 0.0),\n",
" (7, [10, 31], 0.0),\n",
" (8, [81, 13], 0.0),\n",
" (9, [1, 111], 0.0)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_train_data_reg;\n",
"\n",
"CREATE TABLE knn_train_data_reg (\n",
" id integer, \n",
" data integer[], \n",
" label float -- Float label means for regression\n",
" );\n",
"\n",
"INSERT INTO knn_train_data_reg VALUES\n",
"(1, '{1,1}', 1.0),\n",
"(2, '{2,2}', 1.0),\n",
"(3, '{3,3}', 1.0),\n",
"(4, '{4,4}', 1.0),\n",
"(5, '{4,5}', 1.0),\n",
"(6, '{20,50}', 0.0),\n",
"(7, '{10,31}', 0.0),\n",
"(8, '{81,13}', 0.0),\n",
"(9, '{1,111}', 0.0);\n",
"\n",
"SELECT * FROM knn_train_data_reg ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Load testing data"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"6 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1]),\n",
" (2, [2, 6]),\n",
" (3, [15, 40]),\n",
" (4, [12, 1]),\n",
" (5, [2, 90]),\n",
" (6, [50, 45])]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql \n",
"DROP TABLE IF EXISTS knn_test_data;\n",
"\n",
"CREATE TABLE knn_test_data (\n",
" id integer, \n",
" data integer[]\n",
" );\n",
"\n",
"INSERT INTO knn_test_data VALUES\n",
"(1, '{2,1}'),\n",
"(2, '{2,6}'),\n",
"(3, '{15,40}'),\n",
"(4, '{12,1}'),\n",
"(5, '{2,90}'),\n",
"(6, '{50,45}');\n",
"\n",
"SELECT * from knn_test_data ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Run KNN for classification\n",
"Note that the nearest neighbors are sorted from closest to furthest from the corresponding test point."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>prediction</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>1.0</td>\n",
" <td>[2, 1, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>1.0</td>\n",
" <td>[5, 4, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>0.0</td>\n",
" <td>[7, 6, 5]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" <td>1.0</td>\n",
" <td>[4, 5, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>0.0</td>\n",
" <td>[9, 6, 7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>0.0</td>\n",
" <td>[6, 7, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], 1.0, [2, 1, 3]),\n",
" (2, [2, 6], 1.0, [5, 4, 3]),\n",
" (3, [15, 40], 0.0, [7, 6, 5]),\n",
" (4, [12, 1], 1.0, [4, 5, 3]),\n",
" (5, [2, 90], 0.0, [9, 6, 7]),\n",
" (6, [50, 45], 0.0, [6, 7, 8])]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_classification;\n",
"\n",
"SELECT * FROM madlib.knn(\n",
" 'knn_train_data', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col name of id in train data\n",
" 'label', -- Training labels\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_classification', -- Output table\n",
" 3, -- Number of nearest neighbors\n",
" True, -- True to list nearest-neighbors by id\n",
" 'madlib.squared_dist_norm2' -- Distance function\n",
" );\n",
"\n",
"SELECT * from knn_result_classification ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Run KNN for regression"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>prediction</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>1.0</td>\n",
" <td>[2, 1, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>1.0</td>\n",
" <td>[5, 4, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>0.333333333333</td>\n",
" <td>[7, 6, 5]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" <td>1.0</td>\n",
" <td>[4, 5, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>0.0</td>\n",
" <td>[9, 6, 7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>0.0</td>\n",
" <td>[6, 7, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], 1.0, [2, 1, 3]),\n",
" (2, [2, 6], 1.0, [5, 4, 3]),\n",
" (3, [15, 40], 0.333333333333333, [7, 6, 5]),\n",
" (4, [12, 1], 1.0, [4, 5, 3]),\n",
" (5, [2, 90], 0.0, [9, 6, 7]),\n",
" (6, [50, 45], 0.0, [6, 7, 8])]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_regression;\n",
"\n",
"SELECT * FROM madlib.knn(\n",
" 'knn_train_data_reg', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col Name of id in train data\n",
" 'label', -- Training labels\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_regression', -- Output table\n",
" 3, -- Number of nearest neighbors\n",
" True, -- True to list nearest-neighbors by id\n",
" 'madlib.dist_norm2' -- Distance function\n",
" );\n",
"\n",
"SELECT * FROM knn_result_regression ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. List nearest neighbors only\n",
"(without doing classification or regression). Note that the nearest neighbors are sorted from closest to furthest from the corresponding test point."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>[2, 1, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>[5, 4, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>[7, 6, 5]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" <td>[4, 5, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>[9, 6, 7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>[6, 7, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], [2, 1, 3]),\n",
" (2, [2, 6], [5, 4, 3]),\n",
" (3, [15, 40], [7, 6, 5]),\n",
" (4, [12, 1], [4, 5, 3]),\n",
" (5, [2, 90], [9, 6, 7]),\n",
" (6, [50, 45], [6, 7, 8])]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_list_neighbors;\n",
"\n",
"SELECT * FROM madlib.knn(\n",
" 'knn_train_data_reg', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col Name of id in train data\n",
" NULL, -- NULL training labels means just list neighbors\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_list_neighbors', -- Output table\n",
" 3 -- Number of nearest neighbors\n",
" );\n",
"\n",
"SELECT * FROM knn_result_list_neighbors ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7. Weighted average\n",
"Run classification using weighted average"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>prediction</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>1</td>\n",
" <td>[1, 2, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>1</td>\n",
" <td>[5, 4, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>0</td>\n",
" <td>[7, 6, 5]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" <td>1</td>\n",
" <td>[4, 5, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>0</td>\n",
" <td>[9, 6, 7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>0</td>\n",
" <td>[6, 7, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], 1, [1, 2, 3]),\n",
" (2, [2, 6], 1, [5, 4, 3]),\n",
" (3, [15, 40], 0, [7, 6, 5]),\n",
" (4, [12, 1], 1, [4, 5, 3]),\n",
" (5, [2, 90], 0, [9, 6, 7]),\n",
" (6, [50, 45], 0, [6, 7, 8])]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_classification;\n",
"\n",
"SELECT * FROM madlib.knn(\n",
" 'knn_train_data', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col name of id in train data\n",
" 'label', -- Training labels\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_classification', -- Output table\n",
" 3, -- Number of nearest neighbors\n",
" True, -- True to list nearest-neighbors by id\n",
" 'madlib.squared_dist_norm2', -- Distance function\n",
" True -- For weighted average\n",
" );\n",
"\n",
"SELECT * FROM knn_result_classification ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 8. Use kd-tree algorithm \n",
"Here we build a kd-tree to depth 4 and search half (8) of the 16 leaf nodes (i.e., 2^4)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"6 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>[1, 2, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>[5, 4, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>[7, 6, 5]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>[12, 1]</td>\n",
" <td>[4, 5, 3]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>[9, 6, 7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>[6, 7, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], [1, 2, 3]),\n",
" (2, [2, 6], [5, 4, 3]),\n",
" (3, [15, 40], [7, 6, 5]),\n",
" (4, [12, 1], [4, 5, 3]),\n",
" (5, [2, 90], [9, 6, 7]),\n",
" (6, [50, 45], [6, 7, 8])]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_classification_kd;\n",
"\n",
"SELECT madlib.knn(\n",
" 'knn_train_data', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col name of id in train data\n",
" NULL, -- Training labels\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_classification_kd', -- Output table\n",
" 3, -- Number of nearest neighbors\n",
" True, -- True to list nearest-neighbors by id\n",
" 'madlib.squared_dist_norm2', -- Distance function\n",
" False, -- For weighted average\n",
" 'kd_tree', -- Use kd-tree\n",
" 'depth=4, leaf_nodes=8' -- Kd-tree options\n",
" \n",
" );\n",
"SELECT * FROM knn_result_classification_kd ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result above is the same as brute force. If we search just 1 leaf node, run-time will be faster but accuracy will be lower. This shows up in this very small data set by not being able to find 3 nearest neighbors for all test points:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>data</th>\n",
" <th>k_nearest_neighbours</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>[2, 1]</td>\n",
" <td>[1]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>[2, 6]</td>\n",
" <td>[3, 2]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>[15, 40]</td>\n",
" <td>[7]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>[2, 90]</td>\n",
" <td>[3, 2]</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>[50, 45]</td>\n",
" <td>[6, 8]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, [2, 1], [1]),\n",
" (2, [2, 6], [3, 2]),\n",
" (3, [15, 40], [7]),\n",
" (5, [2, 90], [3, 2]),\n",
" (6, [50, 45], [6, 8])]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS knn_result_classification_kd;\n",
"\n",
"SELECT madlib.knn(\n",
" 'knn_train_data', -- Table of training data\n",
" 'data', -- Col name of training data\n",
" 'id', -- Col name of id in train data\n",
" NULL, -- Training labels\n",
" 'knn_test_data', -- Table of test data\n",
" 'data', -- Col name of test data\n",
" 'id', -- Col name of id in test data\n",
" 'knn_result_classification_kd', -- Output table\n",
" 3, -- Number of nearest neighbors\n",
" True, -- True to list nearest-neighbors by id\n",
" 'madlib.squared_dist_norm2', -- Distance function\n",
" False, -- For weighted average\n",
" 'kd_tree', -- Use kd-tree\n",
" 'depth=4, leaf_nodes=1' -- Kd-tree options\n",
" \n",
" );\n",
"SELECT * FROM knn_result_classification_kd ORDER BY id;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}