Kmeans: Add automatic optimal cluster estimation
JIRA: MADLIB-1380
This commit adds the option to run the k-means clustering algorithm for a set
of `k` values and to select the optimal `k` together with its associated
cluster centers. It is supported only for the random and kmeans++ (pp) initial
seeding options.
Closes #433
Co-authored-by: Nikhil Kak <nkak@pivotal.io>
Co-authored-by: Ekta Khanna <ekhanna@pivotal.io>
diff --git a/src/ports/postgres/modules/kmeans/kmeans.sql_in b/src/ports/postgres/modules/kmeans/kmeans.sql_in
index 81dea80..1eae525 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.sql_in
@@ -999,7 +999,7 @@
sampled_rel_source = MADLIB_SCHEMA.__unique_string();
sampled_col_name = MADLIB_SCHEMA.__unique_string();
IF (seeding_sample_ratio < 1.0) THEN
- EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source;
+ EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source||' CASCADE';
EXECUTE 'CREATE TEMP TABLE '||sampled_rel_source||' AS
SELECT *
FROM
@@ -1059,7 +1059,7 @@
EXECUTE 'SET client_min_messages TO ' || oldClientMinMessages;
IF (seeding_sample_ratio < 1.0) THEN
- EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source;
+ EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source||' CASCADE';
END IF;
RETURN theResult;
@@ -1120,6 +1120,7 @@
min_frac_reassigned
)</pre>
*/
+
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp(
rel_source VARCHAR,
expr_point VARCHAR,
@@ -1733,3 +1734,175 @@
'MADLIB_SCHEMA.dist_norm2')
$$
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');
+
+/**
+ * @brief Run k-means with kmeans++ seeding once for every value in k and
+ *        select the optimal number of clusters.
+ *
+ * @param rel_source Name of the relation containing the input points
+ * @param output_table Name of the table to create. It receives one row of
+ *        k-means results per value of k, plus a silhouette and/or elbow
+ *        column depending on k_selection_algorithm. A second table named
+ *        output_table with the suffix _summary receives the row of the
+ *        selected k and the name of the selection algorithm.
+ * @param expr_point Expression evaluating to the coordinates of a point
+ * @param k Array of candidate numbers of clusters. It must contain more than
+ *        one element, every element must be greater than 1 and duplicates
+ *        are not allowed.
+ * @param fn_dist Name of the distance function. NULL selects
+ *        squared_dist_norm2.
+ * @param agg_centroid Name of the aggregate used to compute centroids.
+ *        NULL selects avg.
+ * @param max_num_iterations Maximum number of iterations. NULL selects 20.
+ * @param min_frac_reassigned Convergence threshold on the fraction of
+ *        reassigned points. NULL selects 0.001.
+ * @param seeding_sample_ratio Fraction of the data used for kmeans++
+ *        seeding. NULL selects 1.0.
+ * @param k_selection_algorithm One of elbow, silhouette or both; matching is
+ *        case-insensitive and a prefix is accepted. NULL selects silhouette.
+ *        With both, the two measures are stored but the optimal k is chosen
+ *        by silhouette.
+ *
+ * @returns Nothing. The results are written to the two output tables.
+ */
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ expr_point VARCHAR,
+ k INTEGER[],
+ fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+ agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+ max_num_iterations INTEGER /*+ DEFAULT 20 */,
+ min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+ seeding_sample_ratio DOUBLE PRECISION /*+ DEFAULT 1.0 */,
+ k_selection_algorithm VARCHAR /*+ DEFAULT 'silhouette' */
+) RETURNS VOID AS $$
+ PythonFunction(`kmeans', `kmeans_auto', `kmeanspp_auto')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+/*
+ * Convenience overloads of kmeanspp_auto. Each one forwards to the full
+ * ten-argument signature and passes NULL for every omitted trailing argument;
+ * the Python implementation replaces NULL with the documented default.
+ *
+ * The wrappers end up creating the output tables, so they are declared
+ * MODIFIES SQL DATA, the same as the function they call.
+ */
+
+-- Overload without k_selection_algorithm.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+    seeding_sample_ratio DOUBLE PRECISION /*+ DEFAULT 1.0 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload without seeding_sample_ratio and k_selection_algorithm.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at max_num_iterations.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at agg_centroid. Four NULLs are passed so that the call
+-- resolves directly to the ten-argument function.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at fn_dist.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload with the required arguments only.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[]
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+/**
+ * @brief Run k-means with random seeding once for every value in k and
+ *        select the optimal number of clusters.
+ *
+ * @param rel_source Name of the relation containing the input points
+ * @param output_table Name of the table to create. It receives one row of
+ *        k-means results per value of k, plus a silhouette and/or elbow
+ *        column depending on k_selection_algorithm. A second table named
+ *        output_table with the suffix _summary receives the row of the
+ *        selected k and the name of the selection algorithm.
+ * @param expr_point Expression evaluating to the coordinates of a point
+ * @param k Array of candidate numbers of clusters. It must contain more than
+ *        one element, every element must be greater than 1 and duplicates
+ *        are not allowed.
+ * @param fn_dist Name of the distance function. NULL selects
+ *        squared_dist_norm2.
+ * @param agg_centroid Name of the aggregate used to compute centroids.
+ *        NULL selects avg.
+ * @param max_num_iterations Maximum number of iterations. NULL selects 20.
+ * @param min_frac_reassigned Convergence threshold on the fraction of
+ *        reassigned points. NULL selects 0.001.
+ * @param k_selection_algorithm One of elbow, silhouette or both; matching is
+ *        case-insensitive and a prefix is accepted. NULL selects silhouette.
+ *
+ * @returns Nothing. The results are written to the two output tables.
+ */
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ expr_point VARCHAR,
+ k INTEGER[],
+ fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+ agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+ max_num_iterations INTEGER /*+ DEFAULT 20 */,
+ min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+ k_selection_algorithm VARCHAR /*+ DEFAULT 'silhouette' */
+) RETURNS VOID AS $$
+ PythonFunction(`kmeans', `kmeans_auto', `kmeans_random_auto')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+/*
+ * Convenience overloads of kmeans_random_auto. Each one forwards to the full
+ * nine-argument signature and passes NULL for every omitted trailing
+ * argument; the Python implementation replaces NULL with the documented
+ * default.
+ *
+ * The wrappers end up creating the output tables, so they are declared
+ * MODIFIES SQL DATA, the same as the function they call.
+ */
+
+-- Overload without k_selection_algorithm.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, $7, $8, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at max_num_iterations.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, $7, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at agg_centroid. Three NULLs are passed so that the
+-- call resolves directly to the nine-argument function.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload that stops at fn_dist.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Overload with the required arguments only.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[]
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/kmeans/kmeans_auto.py_in b/src/ports/postgres/modules/kmeans/kmeans_auto.py_in
new file mode 100644
index 0000000..5eb7d3c
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/kmeans_auto.py_in
@@ -0,0 +1,223 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+@file kmeans_auto.py_in
+
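+@brief Automatic selection of the number of clusters (k) for k-means, using
+the elbow and/or simple silhouette methods.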
+
+"""
+
+import numpy as np
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import output_tbl_valid
+from utilities.validate_args import get_algorithm_name
+
+# Supported values of the k_selection_algorithm argument.
+ELBOW = 'elbow'
+SILHOUETTE = 'silhouette'
+BOTH = 'both'
+
+# Centroid seeding methods, passed as the seeding argument of kmeans_auto().
+RANDOM = 'random'
+PP = 'pp'
+
+def _validate(output_table, k):
+    """Check the output table names and the candidate k values.
+
+    Both output_table and '<output_table>_summary' are passed through
+    output_tbl_valid, then k is required to be non-empty, to hold more than
+    one value, to have a minimum greater than 1 and to be free of duplicates.
+
+    Args:
+        @param output_table: Name of the table to be created by kmeans_auto.
+        @param k: List of candidate numbers of clusters.
+    """
+    module_name = "kmeans_auto"
+    for table_name in (output_table, '{0}_summary'.format(output_table)):
+        output_tbl_valid(table_name, module_name)
+
+    _assert(k, "kmeans_auto: k cannot be NULL.")
+    num_values = len(k)
+    _assert(num_values > 1,
+            "kmeans_auto: Length of k array should be more than 1.")
+    _assert(min(k) > 1, "kmeans_auto: the minimum k value has to be > 1.")
+    _assert(len(set(k)) == num_values,
+            "kmeans_auto: Duplicate values are not allowed in k.")
+
+
+def set_defaults(schema_madlib, fn_dist, agg_centroid, max_num_iterations,
+                 min_frac_reassigned, k_selection_algorithm, seeding,
+                 seeding_sample_ratio):
+    """Replace unset optional arguments with their default values.
+
+    Args:
+        @param schema_madlib: Name of the MADlib schema, used to qualify the
+                              default distance function and aggregate.
+        @param fn_dist: Distance function name; a falsy value selects
+                        '<schema_madlib>.squared_dist_norm2'.
+        @param agg_centroid: Centroid aggregate name; a falsy value selects
+                             '<schema_madlib>.avg'.
+        @param max_num_iterations: A falsy value (None or 0) selects 20.
+        @param min_frac_reassigned: A falsy value (None or 0) selects 0.001.
+        @param k_selection_algorithm: Requested selection algorithm, resolved
+                                      through get_algorithm_name against
+                                      elbow, silhouette and both, with
+                                      silhouette as the default.
+        @param seeding: Seeding method, RANDOM or PP.
+        @param seeding_sample_ratio: Only defaulted (to 1.0) for PP seeding;
+                                     returned unchanged otherwise.
+
+    Returns:
+        Tuple (fn_dist, agg_centroid, max_num_iterations,
+        min_frac_reassigned, k_selection_algorithm, seeding_sample_ratio).
+    """
+    fn_dist = fn_dist or '{0}.squared_dist_norm2'.format(schema_madlib)
+    agg_centroid = agg_centroid or '{0}.avg'.format(schema_madlib)
+    max_num_iterations = max_num_iterations or 20
+    min_frac_reassigned = min_frac_reassigned or 0.001
+
+    k_selection_algorithm = get_algorithm_name(
+        k_selection_algorithm, SILHOUETTE, [ELBOW, SILHOUETTE, BOTH],
+        'kmeans_auto')
+
+    # Compare by value: identity of equal strings is an interpreter detail.
+    if seeding == PP and seeding_sample_ratio is None:
+        seeding_sample_ratio = 1.0
+
+    return (fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+            k_selection_algorithm, seeding_sample_ratio)
+
+def kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+                fn_dist=None, agg_centroid=None, max_num_iterations=None,
+                min_frac_reassigned=None, k_selection_algorithm=None,
+                seeding=None, seeding_sample_ratio=None, **kwargs):
+    """Run k-means for every candidate k and record the optimal one.
+
+    Creates output_table with one row per value of k (the k-means result plus
+    a silhouette and/or elbow column) and '<output_table>_summary' holding the
+    row of the selected k together with the name of the selection algorithm.
+    When both measures are requested, the optimal k is chosen by silhouette.
+
+    Args:
+        @param schema_madlib: Name of the MADlib schema.
+        @param rel_source: Name of the relation containing the input points.
+        @param output_table: Name of the output table to create.
+        @param expr_point: Expression evaluating to the point coordinates.
+        @param k: List of candidate numbers of clusters (see _validate).
+        @param fn_dist: Distance function name, defaulted by set_defaults.
+        @param agg_centroid: Centroid aggregate name, defaulted by
+                             set_defaults.
+        @param max_num_iterations: Iteration limit, defaulted by set_defaults.
+        @param min_frac_reassigned: Convergence threshold, defaulted by
+                                    set_defaults.
+        @param k_selection_algorithm: elbow, silhouette or both (a prefix is
+                                      accepted); defaults to silhouette.
+        @param seeding: RANDOM to use kmeans_random; any other value uses
+                        kmeanspp.
+        @param seeding_sample_ratio: Sample ratio passed to kmeanspp.
+
+    Returns:
+        None. The results are written to the two output tables.
+    """
+    with MinWarning("error"):
+        _validate(output_table, k)
+
+        (fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+         k_selection_algorithm, seeding_sample_ratio) = set_defaults(
+            schema_madlib, fn_dist, agg_centroid, max_num_iterations,
+            min_frac_reassigned, k_selection_algorithm, seeding,
+            seeding_sample_ratio)
+
+        # If the selection is silhouette or both, calculate silhouette
+        use_silhouette = k_selection_algorithm in [SILHOUETTE, BOTH]
+        # If the selection is elbow or both, calculate elbow
+        use_elbow = k_selection_algorithm in [ELBOW, BOTH]
+
+        silhouette_col = ""
+        elbow_col = ""
+        if use_silhouette:
+            silhouette_col = ", {0} DOUBLE PRECISION".format(SILHOUETTE)
+        if use_elbow:
+            elbow_col = ", {0} DOUBLE PRECISION".format(ELBOW)
+
+        plpy.execute("""
+            CREATE TABLE {output_table} (
+                k INTEGER,
+                centroids DOUBLE PRECISION[][],
+                cluster_variance DOUBLE PRECISION[],
+                objective_fn DOUBLE PRECISION,
+                frac_reassigned DOUBLE PRECISION,
+                num_iterations INTEGER
+                {silhouette_col}
+                {elbow_col})
+            """.format(output_table=output_table,
+                       silhouette_col=silhouette_col,
+                       elbow_col=elbow_col))
+
+        # The two seeding variants differ only in the function name and in
+        # the trailing sample-ratio argument of kmeanspp. Compare by value:
+        # identity of equal strings is an interpreter detail.
+        if seeding == RANDOM:
+            kmeans_fn = 'kmeans_random'
+            seeding_arg = ''
+        else:
+            kmeans_fn = 'kmeanspp'
+            seeding_arg = ', {0}'.format(seeding_sample_ratio)
+
+        silhouette_vals = []
+
+        for current_k in k:
+            plpy.execute("""
+                INSERT INTO {output_table}
+                    (k, centroids, cluster_variance, objective_fn,
+                     frac_reassigned, num_iterations)
+                SELECT {current_k} as k, *
+                FROM {schema_madlib}.{kmeans_fn}('{rel_source}',
+                                                 '{expr_point}',
+                                                 {current_k},
+                                                 '{fn_dist}',
+                                                 '{agg_centroid}',
+                                                 {max_num_iterations},
+                                                 {min_frac_reassigned}
+                                                 {seeding_arg});
+                """.format(**locals()))
+
+            if use_silhouette:
+                silhouette_query = """
+                    SELECT * FROM {schema_madlib}.simple_silhouette(
+                        '{rel_source}',
+                        '{expr_point}',
+                        (SELECT centroids
+                         FROM {output_table}
+                         WHERE k = {current_k}),
+                        '{fn_dist}')
+                    """.format(**locals())
+                silhouette_vals.append(
+                    plpy.execute(silhouette_query)[0]['simple_silhouette'])
+
+        # Template filled in a second format() call with the column to set,
+        # the k values and the measure computed for each of those k values.
+        update_query = """
+            UPDATE {output_table} SET {{column}} = __value__ FROM
+                (SELECT unnest(ARRAY[{{k_arr}}]) AS __k__,
+                        unnest(ARRAY[{{calc_arr}}]) AS __value__
+                )sub_q
+            WHERE __k__ = k
+            """.format(output_table=output_table)
+
+        if use_silhouette:
+            # silhouette_vals was filled in the order of k as supplied.
+            optimal_sil = k[np.argmax(np.array(silhouette_vals))]
+            plpy.execute(update_query.format(
+                column=SILHOUETTE,
+                k_arr=str(k)[1:-1],
+                calc_arr=str(silhouette_vals)[1:-1]))
+
+        if use_elbow:
+            # _calculate_elbow returns its values ordered by k ascending, so
+            # they have to be paired with the sorted k values, not with k in
+            # the order supplied by the caller.
+            optimal_elbow, second_order = _calculate_elbow(output_table)
+            plpy.execute(update_query.format(
+                column=ELBOW,
+                k_arr=str(sorted(k))[1:-1],
+                calc_arr=str(second_order)[1:-1]))
+
+        # Silhouette takes precedence when both measures were computed.
+        if use_silhouette:
+            optimal_k = optimal_sil
+            selection_algorithm = SILHOUETTE
+        else:
+            optimal_k = optimal_elbow
+            selection_algorithm = ELBOW
+
+        plpy.execute("""
+            CREATE TABLE {output_table}_summary AS
+            SELECT {output_table}.*,
+                   '{algorithm}'::VARCHAR AS selection_algorithm
+            FROM {output_table}
+            WHERE k = {optimal_k}
+            """.format(output_table=output_table,
+                       algorithm=selection_algorithm,
+                       optimal_k=optimal_k))
+
+def _calculate_elbow(output_table):
+    """Find the k at which the objective-function curve bends the most.
+
+    Reads (k, objective_fn) from output_table, takes the numerical second
+    derivative of objective_fn with respect to k and picks the k where that
+    derivative is largest.
+
+    Args:
+        @param output_table: Name of the table holding one row per k.
+
+    Returns:
+        Tuple (k at the elbow, list of second-derivative values ordered by
+        k ascending).
+    """
+    # The elbow is only defined for ordered values, hence the ORDER BY.
+    rows = plpy.execute("""
+        SELECT k, objective_fn FROM {output_table} ORDER BY k ASC
+        """.format(output_table=output_table))
+
+    k_values = []
+    objective_values = []
+    for row in rows:
+        k_values.append(row['k'])
+        objective_values.append(row['objective_fn'])
+    objective_values = np.array(objective_values)
+
+    # Passing k as the coordinates lets np.gradient handle uneven spacing.
+    slope = np.gradient(objective_values, k_values)
+    curvature = np.gradient(slope, k_values)
+    elbow_k = k_values[np.argmax(curvature)]
+
+    return elbow_k, curvature.tolist()
+
+def kmeans_random_auto(schema_madlib, rel_source, output_table, expr_point, k,
+                       fn_dist=None, agg_centroid=None,
+                       max_num_iterations=None, min_frac_reassigned=None,
+                       k_selection_algorithm=None, **kwargs):
+    """Entry point of the SQL function kmeans_random_auto.
+
+    Forwards every argument to kmeans_auto with the seeding method fixed to
+    RANDOM. Returns None; see kmeans_auto for the tables that are created.
+    """
+    kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+                fn_dist=fn_dist,
+                agg_centroid=agg_centroid,
+                max_num_iterations=max_num_iterations,
+                min_frac_reassigned=min_frac_reassigned,
+                k_selection_algorithm=k_selection_algorithm,
+                seeding=RANDOM)
+
+def kmeanspp_auto(schema_madlib, rel_source, output_table, expr_point, k,
+                  fn_dist=None, agg_centroid=None, max_num_iterations=None,
+                  min_frac_reassigned=None, seeding_sample_ratio=None,
+                  k_selection_algorithm=None, **kwargs):
+    """Entry point of the SQL function kmeanspp_auto.
+
+    Forwards every argument to kmeans_auto with the seeding method fixed to
+    PP. Returns None; see kmeans_auto for the tables that are created.
+    """
+    kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+                fn_dist=fn_dist,
+                agg_centroid=agg_centroid,
+                max_num_iterations=max_num_iterations,
+                min_frac_reassigned=min_frac_reassigned,
+                k_selection_algorithm=k_selection_algorithm,
+                seeding=PP,
+                seeding_sample_ratio=seeding_sample_ratio)
diff --git a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
index 8d790fa..4553b6c 100644
--- a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
@@ -91,6 +91,121 @@
SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
+-- Test kmeanspp_auto
+-- Each case drops the two output tables, runs kmeanspp_auto over k = 2..8 and
+-- asserts on the single row written to the summary table.
+
+-- Selection by elbow: the summary row carries an elbow value.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'elbow');
+
+SELECT assert(
+ elbow > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_pp failed for elbow.')
+FROM autokm_out_summary;
+
+-- Selection by silhouette: the summary row carries a silhouette value.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'silhouette');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_pp failed for silhouette.')
+FROM autokm_out_summary;
+
+-- Both measures requested: only the silhouette column is asserted on here.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'both');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_pp failed for both.')
+FROM autokm_out_summary;
+
+-- Required arguments only: the default selection algorithm is silhouette.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8]);
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_pp failed for default.')
+FROM autokm_out_summary;
+
+-- Test kmeans_random_auto
+-- Each case drops the two output tables, runs kmeans_random_auto and asserts
+-- on the single row written to the summary table.
+
+-- Selection by elbow, algorithm name given in full.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'elbow');
+
+SELECT assert(
+ elbow > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for elbow.')
+FROM autokm_out_summary;
+
+-- Selection by elbow, given as a mixed-case prefix of the algorithm name.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'El');
+
+SELECT assert(
+ elbow > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for elbow.')
+FROM autokm_out_summary;
+
+-- Required arguments only: the default selection algorithm is silhouette.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[5,6,7,8]);
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for default.')
+FROM autokm_out_summary;
+
+
+-- The output table holds exactly one row per requested k.
+SELECT assert(count(*) = 4, 'Kmeans: Auto Kmeans_random output has incorrect number of rows')
+FROM (SELECT * FROM autokm_out WHERE k = any(ARRAY[5,6,7,8]))q;
+
+-- Unordered k list test
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[12,3,5,6,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouette');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for silhouette on unordered k vals')
+FROM autokm_out_summary;
+
+-- Selection by silhouette, algorithm name in mixed case.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouetTe');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for silhouette.')
+FROM autokm_out_summary;
+
+-- Selection by silhouette, given as a mixed-case prefix.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'siL');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0,
+ 'Kmeans: Auto Kmeans_random failed for silhouette.')
+FROM autokm_out_summary;
+
+-- Both measures requested: the summary reports silhouette as the selector.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'both');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0 AND
+ selection_algorithm = 'silhouette',
+ 'Kmeans: Auto Kmeans_random failed for both.')
+FROM autokm_out_summary;
+
+-- Both measures requested through the one-letter prefix b.
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001, 'b');
+
+SELECT assert(
+ silhouette > 0 AND objective_fn > 0 AND
+ selection_algorithm = 'silhouette',
+ 'Kmeans: Auto Kmeans_random failed for both.')
+FROM autokm_out_summary;
+
DROP TABLE IF EXISTS km_sample CASCADE;
DROP TABLE IF EXISTS centroids CASCADE;
DROP TABLE IF EXISTS kmeans_2d CASCADE;
diff --git a/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in b/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in
new file mode 100644
index 0000000..dd18649
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in
@@ -0,0 +1,43 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
+# Minimal stand-in for the plpy module so that modules importing plpy can be
+# unit tested outside the database.
+
+# NOTE(review): a module-level __init__ is never called implicitly; it looks
+# like a leftover from a class-based mock -- confirm before removing.
+def __init__(self):
+ pass
+
+# Raises PLPYException carrying the message, so tests can assert on errors.
+def error(message):
+ raise PLPYException(message)
+
+# No-op; unit tests replace this attribute with a mock when they need results.
+def execute(query):
+ pass
+
+# No-op: warnings are discarded.
+def warning(query):
+ pass
+
+# Prints the message to stdout (Python 2 print statement).
+def info(query):
+ print query
+
+
+# Exception type raised by error() above.
+class PLPYException(Exception):
+ def __init__(self, message):
+ super(PLPYException, self).__init__()
+ self.message = message
+
+ # Returns the repr of the stored message.
+ def __str__(self):
+ return repr(self.message)
diff --git a/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in b/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in
new file mode 100644
index 0000000..b56f4ad
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in
@@ -0,0 +1,87 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import numpy as np
+from os import path
+
+# Add modules to the pythonpath.
+sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+import unittest
+from mock import *
+import plpy_mock as plpy
+
+m4_changequote(`<!', `!>')
+
+class KmeansAutoTestCase(unittest.TestCase):
+ """Unit tests for kmeans_auto that run without a database connection."""
+
+ def setUp(self):
+ # NOTE(review): self.plpy_mock is not referenced by any test visible in
+ # this class -- confirm whether it is still needed.
+ self.plpy_mock = Mock(spec='error')
+ # Modules substituted in sys.modules while kmeans_auto is imported.
+ patches = {
+ 'plpy': plpy,
+ 'utilities.mean_std_dev_calculator': Mock(),
+ }
+ # we need to use MagicMock() instead of Mock() for the plpy.execute mock
+ # to be able to iterate on the return value
+ self.plpy_mock_execute = MagicMock()
+ plpy.execute = self.plpy_mock_execute
+
+ self.module_patcher = patch.dict('sys.modules', patches)
+ self.module_patcher.start()
+
+ self.default_schema_madlib = "madlib"
+ self.default_source_table = "source"
+ self.default_output_table = "output"
+
+ # Imported only after the patches are active so that kmeans_auto binds
+ # to the mocked plpy.
+ import kmeans_auto
+ self.module = kmeans_auto
+
+
+ def tearDown(self):
+ # Restore the original sys.modules entries.
+ self.module_patcher.stop()
+
+ def test_calculate_elbow_evenly_spaced(self):
+ # k values one apart: the second derivative peaks at k = 3.
+
+ self.plpy_mock_execute.return_value = [
+ {'k':2, 'objective_fn':100 },
+ {'k':3, 'objective_fn':50 },
+ {'k':4, 'objective_fn':25 },
+ {'k':5, 'objective_fn':20 },
+ {'k':6, 'objective_fn':10 }
+ ]
+ elbow,_ = self.module._calculate_elbow('foo')
+ self.assertEqual(3, elbow)
+
+ def test_calculate_elbow_unevenly_spaced(self):
+ # Gaps of 2 and 1 between k values: the elbow is expected at k = 6.
+
+ self.plpy_mock_execute.return_value = [
+ {'k':2, 'objective_fn':100 },
+ {'k':4, 'objective_fn':80 },
+ {'k':6, 'objective_fn':25 },
+ {'k':7, 'objective_fn':20 },
+ {'k':8, 'objective_fn':10 }
+ ]
+ elbow,_ = self.module._calculate_elbow('foo')
+ self.assertEqual(6, elbow)
+
+if __name__ == '__main__':
+ unittest.main()
+
+# ---------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 6d681e2..eb2150e 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -50,6 +50,7 @@
from utilities.validate_args import is_col_array
from utilities.validate_args import is_var_valid
from utilities.validate_args import quote_ident
+from utilities.validate_args import get_algorithm_name
WEIGHT_FOR_ZERO_DIST = 1e107
BRUTE_FORCE = 'brute_force'
@@ -421,25 +422,6 @@
# ------------------------------------------------------------------------------
-def _get_algorithm_name(algorithm):
- if not algorithm:
- algorithm = BRUTE_FORCE
- else:
- supported_algorithms = [BRUTE_FORCE, KD_TREE]
- try:
- # allow user to specify a prefix substring of
- # supported algorithms. This works because the supported
- # algorithms have unique prefixes.
- algorithm = next(x for x in supported_algorithms
- if x.startswith(algorithm))
- except StopIteration:
- # next() returns a StopIteration if no element found
- plpy.error("kNN Error: Invalid algorithm: "
- "{0}. Supported algorithms are ({1})"
- .format(algorithm, ','.join(sorted(supported_algorithms))))
- return algorithm
-# ------------------------------------------------------------------------------
-
def knn(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name, test_id, output_table,
k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params,
@@ -489,7 +471,8 @@
if k is None:
k = 1
- algorithm = _get_algorithm_name(algorithm)
+ algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
+ [BRUTE_FORCE, KD_TREE], 'kNN')
# Default values for depth and leaf nodes
depth = 3
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
index a3f2539..063d762 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
@@ -49,11 +49,11 @@
def test_input_tbl_valid_null_tbl_raises_exception(self):
with self.assertRaises(plpy.PLPYException):
self.subject.input_tbl_valid(None, "unittest_module")
-
+
def test_input_tbl_valid_whitespaces_tbl_raises(self):
with self.assertRaises(plpy.PLPYException):
self.subject.input_tbl_valid(" ", "unittest_module")
-
+
def test_input_tbl_valid_table_not_exists_raises(self):
self.subject.table_exists = Mock(return_value=False)
with self.assertRaises(plpy.PLPYException):
@@ -113,5 +113,22 @@
self.subject.input_tbl_valid("foo", "unittest_module")
self.assertNotIn('custom exception', str(error.exception))
+ def test_get_algorithm_name(self):
+     """Exact names and prefixes resolve to the first matching algorithm."""
+     cases = [
+         # (expected, requested, default, supported algorithms)
+         ('abc', 'abc', 'aaa', ['aaa', 'abc', 'bcd']),
+         ('aaa', 'aaa', 'aaa', ['aaa', 'abc', 'bcd']),
+         ('aaa', 'aa', 'aaa', ['aaa', 'abc', 'bcd']),
+         ('bcd', 'bc', 'aaa', ['aaa', 'abc', 'bcd']),
+         # If two options satisfy the given selection,
+         # pick the first one from the list
+         ('aaa', 'a', 'abc', ['aaa', 'abc', 'bcd']),
+         ('aqq', 'a', 'abc', ['aqq', 'abc', 'bcd']),
+     ]
+     for expected, requested, default, supported in cases:
+         self.assertEqual(expected, self.subject.get_algorithm_name(
+             requested, default, supported, 'qwerty'))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index e0758d3..ea4d133 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -734,6 +734,26 @@
', '.join(intersect)))
# -------------------------------------------------------------------------
+
+def get_algorithm_name(algorithm, default, supported_algorithms, module):
+    """Resolve a user-supplied algorithm name to a supported algorithm.
+
+    Args:
+        @param algorithm: Name requested by the user. Matching is
+                          case-insensitive and a prefix of a supported name
+                          is accepted. A falsy value selects the default.
+        @param default: Value returned when no algorithm is given.
+        @param supported_algorithms: List of supported names. When several
+                                     of them start with the requested text,
+                                     the first one in the list wins.
+        @param module: Module name used as the prefix of the error message.
+
+    Returns:
+        The matching supported name, or default when algorithm is falsy.
+        plpy.error is called when nothing matches.
+    """
+    if not algorithm:
+        return default
+
+    requested = algorithm.lower()
+    for candidate in supported_algorithms:
+        if candidate.startswith(requested):
+            return candidate
+
+    # No supported algorithm starts with the requested text.
+    plpy.error("{0} Error: Invalid algorithm: "
+               "{1}. Supported algorithms are ({2})"
+               .format(module, requested,
+                       ','.join(sorted(supported_algorithms))))
+    return requested
+
+
import unittest
@@ -749,6 +769,5 @@
self.assertEqual('Test123', unquote_ident('"Test123"'))
self.assertEqual('test', unquote_ident('"test"'))
-
if __name__ == '__main__':
unittest.main()