DBSCAN: Fix predict on Greenplum (#499)
* DBSCAN: Fix predict on Greenplum
The previous dbscan_predict implementation failed on Greenplum
when called on a table, because UDFs running on segments cannot
access tables. This implementation writes its results to an
output table to fix this issue.
* Remove outdated example
diff --git a/src/ports/postgres/modules/dbscan/dbscan.py_in b/src/ports/postgres/modules/dbscan/dbscan.py_in
index 7adc971..e8d67a6 100644
--- a/src/ports/postgres/modules/dbscan/dbscan.py_in
+++ b/src/ports/postgres/modules/dbscan/dbscan.py_in
@@ -180,30 +180,37 @@
reachable_points_table))
-def dbscan_predict(schema_madlib, dbscan_table, new_point, **kwargs):
+def dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
+ expr_point, output_table, **kwargs):
with MinWarning("warning"):
+ _validate_dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
+ expr_point, output_table)
+
dbscan_summary_table = add_postfix(dbscan_table, '_summary')
summary = plpy.execute("SELECT * FROM {0}".format(dbscan_summary_table))[0]
eps = summary['eps']
metric = summary['metric']
+ db_id_column = summary['id_column']
sql = """
- SELECT cluster_id,
- {schema_madlib}.{metric}(__points__, ARRAY{new_point}) as dist
- FROM {dbscan_table}
- WHERE is_core_point = TRUE
- ORDER BY dist LIMIT 1
+ CREATE TABLE {output_table} AS
+ SELECT __q1__.{id_column}, cluster_id, distance
+ FROM (
+ SELECT __t2__.{id_column}, cluster_id,
+ min({schema_madlib}.{metric}(__t1__.__points__,
+ __t2__.{expr_point})) as distance
+ FROM {dbscan_table} AS __t1__, {source_table} AS __t2__
+ WHERE is_core_point = TRUE
+ GROUP BY __t2__.{id_column}, cluster_id
+ ) __q1__
+ WHERE distance <= {eps}
""".format(**locals())
- result = plpy.execute(sql)[0]
- dist = result['dist']
- if dist < eps:
- return result['cluster_id']
- else:
- return None
+ result = plpy.execute(sql)
-def _validate_dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm):
+def _validate_dbscan(schema_madlib, source_table, output_table, id_column,
+ expr_point, eps, min_samples, metric, algorithm):
input_tbl_valid(source_table, 'dbscan')
output_tbl_valid(output_table, 'dbscan')
@@ -228,6 +235,25 @@
fn_dist_list = ['dist_norm1', 'dist_norm2', 'squared_dist_norm2', 'dist_angle', 'dist_tanimoto']
_assert(metric in fn_dist_list, "dbscan Error: metric has to be one of the madlib defined distance functions")
+def _validate_dbscan_predict(schema_madlib, dbscan_table, source_table,
+                             id_column, expr_point, output_table):
+    """Check the dbscan_predict arguments with the module's validation helpers."""
+    input_tbl_valid(source_table, 'dbscan')
+    input_tbl_valid(dbscan_table, 'dbscan')
+    dbscan_summary_table = add_postfix(dbscan_table, '_summary')
+    input_tbl_valid(dbscan_summary_table, 'dbscan')
+    output_tbl_valid(output_table, 'dbscan')
+    # id_column must be a column of the table holding the points to assign.
+    cols_in_tbl_valid(source_table, [id_column], 'dbscan')
+    # expr_point may be a column name or an expression over source_table.
+    _assert(is_var_valid(source_table, expr_point),
+            "dbscan Error: {0} is an invalid column name or "
+            "expression for expr_point param".format(expr_point))
+    # The point expression has to evaluate to a numeric array.
+    point_col_type = get_expr_type(expr_point, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "dbscan Error: Feature column or expression '{0}' in source table is not"
+            " a numeric array.".format(expr_point))
def dbscan_help(schema_madlib, message=None, **kwargs):
"""
diff --git a/src/ports/postgres/modules/dbscan/dbscan.sql_in b/src/ports/postgres/modules/dbscan/dbscan.sql_in
index 7682efc..01d0a55 100644
--- a/src/ports/postgres/modules/dbscan/dbscan.sql_in
+++ b/src/ports/postgres/modules/dbscan/dbscan.sql_in
@@ -179,25 +179,47 @@
<pre class="syntax">
dbscan_predict( dbscan_table,
- new_point
- )
+ source_table,
+ id_column,
+ expr_point,
+ output_table
+ )
</pre>
<b>Arguments</b>
<dl class="arglist">
+
<dt>dbscan_table</dt>
<dd>TEXT. Name of the table created by running DBSCAN.</dd>
-<dt>new_point</dt>
-<dd>DOUBLE PRECISION[]. New points to be assigned to clusters.</dd>
-</dl>
+<dt>source_table</dt>
+<dd>TEXT. Name of the table containing the input data points.
+</dd>
-<b>Output TBD???</b>
+
+<dt>id_column</dt>
+<dd>TEXT. Name of the column containing a unique integer id for each input point.
+</dd>
+
+<dt>expr_point</dt>
+<dd>TEXT. Name of the column with point coordinates in array form,
+or an expression that evaluates to an array of point coordinates.
+</dd>
+
+<dt>output_table</dt>
+<dd>TEXT. Name of the output table that will be created to hold the cluster assignments.
+</dd>
+
+<b>Output TBD</b>
<br>
-The output is a composite type with the following columns:
+The output is a table with the following columns:
<table class="output">
<tr>
- <th>column_id</th>
+ <th>id_column</th>
+ <td>INTEGER. ID column passed to the function.</td>
+ </tr>
+ <tr>
+ <th>cluster_id</th>
<td>INTEGER. Cluster assignment (zero-based, i.e., 0,1,2...).</td>
</tr>
<tr>
@@ -234,12 +256,24 @@
(18, '{10, 4}'),
(19, '{11, 4}'),
(20, '{10, 3}');
+CREATE TABLE dbscan_test_data (pid int, points double precision[]);
+INSERT INTO dbscan_test_data VALUES
+(1, '{1, 2}'),
+(2, '{2, 2}'),
+(3, '{1, 3}'),
+(4, '{2, 2}'),
+(10, '{5, 11}'),
+(11, '{7, 10}'),
+(12, '{10, 9}'),
+(13, '{10, 6}'),
+(14, '{9, 5}'),
+(15, '{10, 6}');
</pre>
-# Run DBSCAN using the brute force method with a Euclidean
distance function:
<pre class="example">
DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary;
-SELECT madlib.dbscan(
+SELECT madlib.dbscan(
'dbscan_train_data', -- source table
'dbscan_result', -- output table
'pid', -- point id column
@@ -251,7 +285,7 @@
SELECT * FROM dbscan_result ORDER BY pid;
</pre>
<pre class="result">
- pid | cluster_id | is_core_point | __points__
+ pid | cluster_id | is_core_point | __points__
-----+------------+---------------+------------
1 | 0 | t | {1,1}
2 | 0 | t | {2,1}
@@ -274,15 +308,15 @@
</pre>
There are three clusters created. All points are core points
except for 6 and 10 which are border points. The noise points
-do not show up in the output table. If you want to see the noise points
+do not show up in the output table. If you want to see the noise points
you can use a query like:
<pre class="example">
-SELECT l.* FROM dbscan_train_data l WHERE NOT EXISTS
- (SELECT NULL FROM dbscan_result r WHERE r.pid = l.pid)
+SELECT l.* FROM dbscan_train_data l WHERE NOT EXISTS
+ (SELECT NULL FROM dbscan_result r WHERE r.pid = l.pid)
ORDER BY l.pid;
</pre>
<pre class="result">
- pid | points
+ pid | points
-----+--------
5 | {3,5}
11 | {7,10}
@@ -293,17 +327,22 @@
SELECT * FROM dbscan_result_summary;
</pre>
<pre class="result">
- id_column | eps | metric
+ id_column | eps | metric
-----------+------+------------
pid | 1.75 | dist_norm2
</pre>
-# Find the cluster assignment. In this example we use the same source
points for demonstration purposes:
<pre class="example">
-SELECT pid, madlib.dbscan_predict (
- 'dbscan_result', -- from DBSCAN run
- points) -- data to cluster
-FROM dbscan_train_data ORDER BY pid;
+
+SELECT madlib.dbscan_predict(
+ 'dbscan_result', -- from DBSCAN run
+ 'dbscan_test_data', -- test dataset
+ 'pid', -- point id column
+ 'points', -- data point
+ 'dbscan_predict_out' -- output table
+ );
+
</pre>
<pre class="result">
TBD???
@@ -312,7 +351,7 @@
distance function:
<pre class="example">
DROP TABLE IF EXISTS dbscan_result_kd, dbscan_result_kd_summary;
-SELECT madlib.dbscan(
+SELECT madlib.dbscan(
'dbscan_train_data', -- source table
'dbscan_result_kd', -- output table
'pid', -- point id column
@@ -391,8 +430,11 @@
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.dbscan_predict(
dbscan_table VARCHAR,
- new_point DOUBLE PRECISION[]
-) RETURNS INTEGER AS $$
+ source_table VARCHAR,
+ id_column VARCHAR,
+ expr_point VARCHAR,
+ output_table VARCHAR
+) RETURNS VOID AS $$
PythonFunction(dbscan, dbscan, dbscan_predict)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/dbscan/test/dbscan.sql_in b/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
index 54fcddf..621251a 100644
--- a/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
+++ b/src/ports/postgres/modules/dbscan/test/dbscan.sql_in
@@ -40,16 +40,16 @@
13|{8,113}
\.
-DROP TABLE IF EXISTS out1, out1_summary;
+DROP TABLE IF EXISTS out1, out1_summary, out1_predict;
SELECT dbscan('dbscan_train_data','out1','id_in','data',20,4,'squared_dist_norm2','brute');
SELECT assert(count(DISTINCT id_in) = 5, 'Incorrect cluster 0') FROM out1 WHERE cluster_id = 0 and id_in=ANY(ARRAY[1,2,3,4,5]);
SELECT assert(count(DISTINCT id_in) = 4, 'Incorrect cluster 1') FROM out1 WHERE cluster_id = 1 and id_in=ANY(ARRAY[6,7,8,9]);
-SELECT assert(dbscan_predict('out1', array[0,0]::double precision[]) = 0, 'Incorrect predict 0');
-SELECT assert(dbscan_predict('out1', array[9.1,10.8]::double precision[]) = 1, 'Incorrect predict 1');
-SELECT assert(dbscan_predict('out1', array[9,113]::double precision[]) IS NULL, 'Incorrect predict NULL');
+SELECT dbscan_predict('out1', 'dbscan_train_data', 'id_in', 'data', 'out1_predict');
+
+SELECT assert(count(DISTINCT cluster_id) = 2, 'Incorrect cluster count') FROM out1_predict;
DROP TABLE IF EXISTS dbscan_train_data2;
CREATE TABLE dbscan_train_data2 (pid int, points double precision[]);
@@ -75,7 +75,21 @@
(19, '{11, 4}'),
(20, '{10, 3}');
-DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary;
+DROP TABLE IF EXISTS dbscan_test_data2;
+CREATE TABLE dbscan_test_data2 (pid int, points double precision[]);
+INSERT INTO dbscan_test_data2 VALUES
+(1, '{1, 2}'),
+(2, '{2, 2}'),
+(3, '{1, 3}'),
+(4, '{2, 2}'),
+(10, '{5, 11}'),
+(11, '{7, 10}'),
+(12, '{10, 9}'),
+(13, '{10, 6}'),
+(14, '{9, 5}'),
+(15, '{10, 6}');
+
+DROP TABLE IF EXISTS dbscan_result, dbscan_result_summary, dbscan_predict_out;
SELECT dbscan(
'dbscan_train_data2', -- source table
'dbscan_result', -- output table
@@ -89,3 +103,7 @@
SELECT * FROM dbscan_result ORDER BY pid;
SELECT assert(count(DISTINCT cluster_id) = 3, 'Incorrect cluster count') FROM dbscan_result;
+
+SELECT dbscan_predict('dbscan_result', 'dbscan_test_data2', 'pid', 'points', 'dbscan_predict_out');
+
+SELECT assert(count(DISTINCT cluster_id) = 3, 'Incorrect cluster count') FROM dbscan_predict_out;