DBSCAN: Fix predict on Greenplum (#499)

* DBSCAN: Fix predict on Greenplum

The previous dbscan_predict implementation failed on greenplum
if called on a table since the UDFs running on segments cannot
access tables. This implementation creates an output table to
fix this issue.

* Remove outdated example
diff --git a/src/ports/postgres/modules/dbscan/dbscan.py_in b/src/ports/postgres/modules/dbscan/dbscan.py_in
index 7adc971..e8d67a6 100644
--- a/src/ports/postgres/modules/dbscan/dbscan.py_in
+++ b/src/ports/postgres/modules/dbscan/dbscan.py_in
@@ -180,30 +180,37 @@
                      reachable_points_table))
 
 
-def dbscan_predict(schema_madlib, dbscan_table, new_point, **kwargs):
+def dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
+    expr_point, output_table, **kwargs):
 
     with MinWarning("warning"):
 
+        _validate_dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
+    expr_point, output_table)
+
         dbscan_summary_table = add_postfix(dbscan_table, '_summary')
         summary = plpy.execute("SELECT * FROM {0}".format(dbscan_summary_table))[0]
 
         eps = summary['eps']
         metric = summary['metric']
+        db_id_column = summary['id_column']
         sql = """
-            SELECT cluster_id,
-                   {schema_madlib}.{metric}(__points__, ARRAY{new_point}) as dist
-            FROM {dbscan_table}
-            WHERE is_core_point = TRUE
-            ORDER BY dist LIMIT 1
+            CREATE TABLE {output_table} AS
+            SELECT __q1__.{id_column}, cluster_id, distance
+            FROM (
+                SELECT __t2__.{id_column}, cluster_id,
+                       min({schema_madlib}.{metric}(__t1__.__points__,
+                                                __t2__.{expr_point})) as distance
+                FROM {dbscan_table} AS __t1__, {source_table} AS __t2__
+                WHERE is_core_point = TRUE
+                GROUP BY __t2__.{id_column}, cluster_id
+                ) __q1__
+            WHERE distance <= {eps}
             """.format(**locals())
-        result = plpy.execute(sql)[0]
-        dist = result['dist']
-        if dist < eps:
-            return result['cluster_id']
-        else:
-            return None
+        result = plpy.execute(sql)
 
-def _validate_dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm):
+def _validate_dbscan(schema_madlib, source_table, output_table, id_column,
+    expr_point, eps, min_samples, metric, algorithm):
 
     input_tbl_valid(source_table, 'dbscan')
     output_tbl_valid(output_table, 'dbscan')
@@ -228,6 +235,25 @@
     fn_dist_list = ['dist_norm1', 'dist_norm2', 'squared_dist_norm2', 'dist_angle', 'dist_tanimoto']
     _assert(metric in fn_dist_list, "dbscan Error: metric has to be one of the madlib defined distance functions")
 
+def _validate_dbscan_predict(schema_madlib, dbscan_table, source_table,
+    id_column, expr_point, output_table):
+
+    input_tbl_valid(source_table, 'dbscan')
+    input_tbl_valid(dbscan_table, 'dbscan')
+    dbscan_summary_table = add_postfix(dbscan_table, '_summary')
+    input_tbl_valid(dbscan_summary_table, 'dbscan')
+    output_tbl_valid(output_table, 'dbscan')
+
+    cols_in_tbl_valid(source_table, [id_column], 'dbscan')
+
+    _assert(is_var_valid(source_table, expr_point),
+            "dbscan error: {0} is an invalid column name or "
+            "expression for expr_point param".format(expr_point))
+
+    point_col_type = get_expr_type(expr_point, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "dbscan Error: Feature column or expression '{0}' in train table is not"
+            " a numeric array.".format(expr_point))
 
 def dbscan_help(schema_madlib, message=None, **kwargs):
     """
diff --git a/src/ports/postgres/modules/dbscan/dbscan.sql_in b/src/ports/postgres/modules/dbscan/dbscan.sql_in
index 7682efc..01d0a55 100644
--- a/src/ports/postgres/modules/dbscan/dbscan.sql_in
+++ b/src/ports/postgres/modules/dbscan/dbscan.sql_in
@@ -179,25 +179,47 @@
 
 <pre class="syntax">
 dbscan_predict( dbscan_table,
-                new_point
-              )
+                source_table,
+                id_column,
+                expr_point,
+                output_table
+                )
 </pre>
 
 <b>Arguments</b>
 <dl class="arglist">
+
 <dt>dbscan_table</dt>
 <dd>TEXT. Name of the table created by running DBSCAN.</dd>
 
-<dt>new_point</dt>
-<dd>DOUBLE PRECISION[]. New points to be assigned to clusters.</dd>
-</dl>
+<dt>source_table</dt>
+<dd>TEXT. Name of the table containing the input data points.
+</dd>
 
-<b>Output TBD???</b>
+
+<dt>id_column</dt>
+<dd>TEXT. Name of the column containing a unique integer id for each training point.
+</dd>
+
+<dt>expr_point</dt>
+<dd>TEXT. Name of the column with point coordinates in array form,
+or an expression that evaluates to an array of point coordinates.
+</dd>
+
+<dt>output_table</dt>
+<dd>TEXT. Name of the table containing the clustering results.
+</dd>
+
+<b>Output TBD</b>
 <br>
-The output is a composite type with the following columns:
+The output is a table with the following columns:
 <table class="output">
     <tr>
-      <th>column_id</th>
+      <th>id_column</th>
+      <td>INTEGER. ID column passed to the function.</td>
+    </tr>
+    <tr>
+      <th>cluster_id</th>
       <td>INTEGER. Cluster assignment (zero-based, i.e., 0,1,2...).</td>
     </tr>
     <tr>
@@ -234,12 +256,24 @@
 (18,  '{10, 4}'),
 (19,  '{11, 4}'),
 (20,  '{10, 3}');
+CREATE TABLE dbscan_test_data (pid int, points double precision[]);
+INSERT INTO dbscan_test_data VALUES
+(1,  '{1, 2}'),
+(2,  '{2, 2}'),
+(3,  '{1, 3}'),
+(4,  '{2, 2}'),
+(10,  '{5, 11}'),
+(11,  '{7, 10}'),
+(12,  '{10, 9}'),
+(13,  '{10, 6}'),
+(14,  '{9, 5}'),
+(15,  '{10, 6}');
 </pre>
 -#  Run DBSCAN using the brute force method with a Euclidean
 distance function:
 <pre class="example">
 DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary;
-SELECT madlib.dbscan( 
+SELECT madlib.dbscan(
                 'dbscan_train_data',    -- source table
                 'dbscan_result',        -- output table
                 'pid',                  -- point id column
@@ -251,7 +285,7 @@
 SELECT * FROM dbscan_result ORDER BY pid;
 </pre>
 <pre class="result">
- pid | cluster_id | is_core_point | __points__ 
+ pid | cluster_id | is_core_point | __points__
 -----+------------+---------------+------------
    1 |          0 | t             | {1,1}
    2 |          0 | t             | {2,1}
@@ -274,15 +308,15 @@
 </pre>
 There are three clusters created.  All points are core points
 except for 6 and 10 which are border points.  The noise points
-do not show up in the output table. If you want to see the noise points 
+do not show up in the output table. If you want to see the noise points
 you can use a query like:
 <pre class="example">
-SELECT l.* FROM dbscan_train_data l WHERE NOT EXISTS 
-    (SELECT NULL FROM dbscan_result r WHERE r.pid = l.pid) 
+SELECT l.* FROM dbscan_train_data l WHERE NOT EXISTS
+    (SELECT NULL FROM dbscan_result r WHERE r.pid = l.pid)
     ORDER BY l.pid;
 </pre>
 <pre class="result">
- pid | points 
+ pid | points
 -----+--------
    5 | {3,5}
   11 | {7,10}
@@ -293,17 +327,22 @@
 SELECT * FROM dbscan_result_summary;
 </pre>
 <pre class="result">
- id_column | eps  |   metric   
+ id_column | eps  |   metric
 -----------+------+------------
  pid       | 1.75 | dist_norm2
 </pre>
 -#  Find the cluster assignment.  In this example we use the same source
 points for demonstration purposes:
 <pre class="example">
-SELECT pid, madlib.dbscan_predict (
-                        'dbscan_result',   -- from DBSCAN run
-                         points)           -- data to cluster
-FROM dbscan_train_data ORDER BY pid;
+
+SELECT madlib.dbscan_predict(
+                        'dbscan_result',        -- from DBSCAN run
+                        'dbscan_test_data',     -- test dataset
+                        'pid',                  -- point id column
+                        'points',               -- data point
+                        'dbscan_predict_out'    -- output table
+                        );
+
 </pre>
 <pre class="result">
 TBD???
@@ -312,7 +351,7 @@
 distance function:
 <pre class="example">
 DROP TABLE IF EXISTS dbscan_result_kd, dbscan_result_kd_summary;
-SELECT madlib.dbscan( 
+SELECT madlib.dbscan(
                 'dbscan_train_data',    -- source table
                 'dbscan_result_kd',     -- output table
                 'pid',                  -- point id column
@@ -391,8 +430,11 @@
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.dbscan_predict(
     dbscan_table                VARCHAR,
-    new_point                   DOUBLE PRECISION[]
-) RETURNS INTEGER AS $$
+    source_table                VARCHAR,
+    id_column                   VARCHAR,
+    expr_point                  VARCHAR,
+    output_table                VARCHAR
+) RETURNS VOID AS $$
     PythonFunction(dbscan, dbscan, dbscan_predict)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/dbscan/test/dbscan.sql_in b/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
index 54fcddf..621251a 100644
--- a/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
+++ b/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
@@ -40,16 +40,16 @@
 13|{8,113}
 \.
 
-DROP TABLE IF EXISTS out1, out1_summary;
+DROP TABLE IF EXISTS out1, out1_summary, out1_predict;
 SELECT dbscan('dbscan_train_data','out1','id_in','data',20,4,'squared_dist_norm2','brute');
 
 SELECT assert(count(DISTINCT id_in) = 5, 'Incorrect cluster 0') FROM out1 WHERE cluster_id = 0 and id_in=ANY(ARRAY[1,2,3,4,5]);
 
 SELECT assert(count(DISTINCT id_in) = 4, 'Incorrect cluster 1') FROM out1 WHERE cluster_id = 1 and id_in=ANY(ARRAY[6,7,8,9]);
 
-SELECT assert(dbscan_predict('out1', array[0,0]::double precision[]) = 0, 'Incorrect predict 0');
-SELECT assert(dbscan_predict('out1', array[9.1,10.8]::double precision[]) = 1, 'Incorrect predict 1');
-SELECT assert(dbscan_predict('out1', array[9,113]::double precision[]) IS NULL, 'Incorrect predict NULL');
+SELECT dbscan_predict('out1', 'dbscan_train_data', 'id_in', 'data', 'out1_predict');
+
+SELECT assert(count(DISTINCT cluster_id) = 2, 'Incorrect cluster count') FROM out1_predict;
 
 DROP TABLE IF EXISTS dbscan_train_data2;
 CREATE TABLE dbscan_train_data2 (pid int, points double precision[]);
@@ -75,7 +75,21 @@
 (19,  '{11, 4}'),
 (20,  '{10, 3}');
 
-DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary;
+DROP TABLE IF EXISTS dbscan_test_data2;
+CREATE TABLE dbscan_test_data2 (pid int, points double precision[]);
+INSERT INTO dbscan_test_data2 VALUES
+(1,  '{1, 2}'),
+(2,  '{2, 2}'),
+(3,  '{1, 3}'),
+(4,  '{2, 2}'),
+(10,  '{5, 11}'),
+(11,  '{7, 10}'),
+(12,  '{10, 9}'),
+(13,  '{10, 6}'),
+(14,  '{9, 5}'),
+(15,  '{10, 6}');
+
+DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary, dbscan_predict_out;
 SELECT dbscan(
 'dbscan_train_data2',    -- source table
 'dbscan_result',        -- output table
@@ -89,3 +103,7 @@
 SELECT * FROM dbscan_result ORDER BY pid;
 
 SELECT assert(count(DISTINCT cluster_id) = 3, 'Incorrect cluster count') FROM dbscan_result;
+
+SELECT dbscan_predict('dbscan_result', 'dbscan_test_data2', 'pid', 'points', 'dbscan_predict_out');
+
+SELECT assert(count(DISTINCT cluster_id) = 3, 'Incorrect cluster count') FROM dbscan_predict_out;