blob: e8d67a6812e3927173fc0d1f002e6fda0ff3b554 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import NUMERIC, ONLY_ARRAY
from utilities.utilities import is_valid_psql_type
from utilities.utilities import is_platform_pg
from utilities.validate_args import input_tbl_valid, output_tbl_valid
from utilities.validate_args import is_var_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import get_expr_type
from utilities.validate_args import get_algorithm_name
from graph.wcc import wcc
BRUTE_FORCE = 'brute_force'
KD_TREE = 'kd_tree'
def dbscan(schema_madlib, source_table, output_table, id_column, expr_point,
           eps, min_samples, metric, algorithm, **kwargs):
    """Run DBSCAN (density-based clustering) over the points in source_table.

    Args:
        @param schema_madlib: str, MADlib schema, used to qualify distance fns
        @param source_table: str, table containing the points to cluster
        @param output_table: str, name of the result table to create
        @param id_column: str, unique point-id column in source_table
        @param expr_point: str, column name or expression yielding the point array
        @param eps: number, neighborhood radius (pairs closer than eps are edges)
        @param min_samples: int, minimum neighborhood size (incl. the point
            itself) for a core point; defaults to 5 when falsy
        @param metric: str, MADlib distance function name; defaults to
            'squared_dist_norm2' when falsy
        @param algorithm: str, resolved against {brute_force, kd_tree}

    Side effects: creates output_table with columns
    ({id_column}, cluster_id, is_core_point, __points__) and the companion
    <output_table>_summary table; all intermediate tables are dropped.
    """
    with MinWarning("warning"):
        # Fill in defaults for the optional arguments.
        min_samples = 5 if not min_samples else min_samples
        metric = 'squared_dist_norm2' if not metric else metric
        algorithm = 'brute' if not algorithm else algorithm
        # Resolve abbreviations/case variants to a canonical algorithm name.
        # NOTE(review): `algorithm` is validated but never used below — the
        # brute-force pairwise-distance path always runs; confirm kd_tree is
        # intentionally unimplemented here.
        algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
                                       [BRUTE_FORCE, KD_TREE], 'DBSCAN')

        _validate_dbscan(schema_madlib, source_table, output_table, id_column,
                         expr_point, eps, min_samples, metric, algorithm)

        # On Greenplum, distribute intermediates by their join keys; on plain
        # PostgreSQL the DISTRIBUTED BY clauses must be empty.
        dist_src_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__src__)'
        dist_id_sql = '' if is_platform_pg() else 'DISTRIBUTED BY ({0})'.format(id_column)
        dist_reach_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__reachable_id__)'

        # Calculate pairwise distances (brute force: O(n^2) self-join), keeping
        # only ordered pairs closer than eps.
        distance_table = unique_string(desp='distance_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))
        sql = """
            CREATE TABLE {distance_table} AS
            SELECT __src__, __dest__ FROM (
                SELECT __t1__.{id_column} AS __src__,
                       __t2__.{id_column} AS __dest__,
                       {schema_madlib}.{metric}(
                           __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__
                FROM {source_table} AS __t1__, {source_table} AS __t2__
                WHERE __t1__.{id_column} != __t2__.{id_column}) q1
            WHERE __dist__ < {eps}
            {dist_src_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Find core points
        # We use __count__ + 1 because we have to add the point itself to the count
        core_points_table = unique_string(desp='core_points_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_points_table))
        sql = """
            CREATE TABLE {core_points_table} AS
            SELECT * FROM (SELECT __src__ AS {id_column}, count(*) AS __count__
                           FROM {distance_table} GROUP BY __src__) q1
            WHERE __count__ + 1 >= {min_samples}
            {dist_id_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Find the connections between core points to form the clusters:
        # keep only edges whose both endpoints are core points.
        core_edge_table = unique_string(desp='core_edge_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_edge_table))
        sql = """
            CREATE TABLE {core_edge_table} AS
            SELECT __src__, __dest__
            FROM {distance_table} AS __t1__, (SELECT array_agg({id_column}) AS arr
                                              FROM {core_points_table}) __t2__
            WHERE __t1__.__src__ = ANY(arr) AND __t1__.__dest__ = ANY(arr)
            {dist_src_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Run wcc (weakly connected components) to get the min id for each
        # cluster; wcc writes output_table with a component_id column.
        wcc(schema_madlib, core_points_table, id_column, core_edge_table,
            'src=__src__, dest=__dest__', output_table, None)
        plpy.execute("""
            ALTER TABLE {0}
            ADD COLUMN is_core_point BOOLEAN,
            ADD COLUMN __points__ DOUBLE PRECISION[]
            """.format(output_table))
        plpy.execute("""
            ALTER TABLE {0}
            RENAME COLUMN component_id TO cluster_id
            """.format(output_table))
        # Every row produced by wcc so far came from core_points_table.
        plpy.execute("""
            UPDATE {0}
            SET is_core_point = TRUE
            """.format(output_table))

        # Find reachable points: non-core points within eps of at least one
        # core point, with the list of core points that reach them.
        reachable_points_table = unique_string(desp='reachable_points_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(reachable_points_table))
        sql = """
            CREATE TABLE {reachable_points_table} AS
            SELECT array_agg(__src__) AS __src_list__,
                   __dest__ AS __reachable_id__
            FROM {distance_table} AS __t1__,
                 (SELECT array_agg({id_column}) AS __arr__
                  FROM {core_points_table}) __t2__
            WHERE __src__ = ANY(__arr__) AND __dest__ != ALL(__arr__)
            GROUP BY __dest__
            {dist_reach_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Border points inherit the cluster of the first core point in their
        # reach list (__src_list__[1]).
        sql = """
            INSERT INTO {output_table}
            SELECT __reachable_id__ as {id_column},
                   cluster_id,
                   FALSE AS is_core_point,
                   NULL AS __points__
            FROM {reachable_points_table} AS __t1__ INNER JOIN
                 {output_table} AS __t2__
                 ON (__src_list__[1] = {id_column})
            """.format(**locals())
        plpy.execute(sql)

        # Add features of points to the output table to use them for prediction
        sql = """
            UPDATE {output_table} AS __t1__
            SET __points__ = {expr_point}
            FROM {source_table} AS __t2__
            WHERE __t1__.{id_column} = __t2__.{id_column}
            """.format(**locals())
        plpy.execute(sql)

        # Update the cluster ids to be consecutive, starting from 0.
        sql = """
            UPDATE {output_table} AS __t1__
            SET cluster_id = new_id-1
            FROM (
                SELECT cluster_id, row_number() OVER(ORDER BY cluster_id) AS new_id
                FROM {output_table}
                GROUP BY cluster_id) __t2__
            WHERE __t1__.cluster_id = __t2__.cluster_id
            """.format(**locals())
        plpy.execute(sql)

        # Record the parameters dbscan_predict() needs to reproduce distances.
        output_summary_table = add_postfix(output_table, '_summary')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(output_summary_table))
        sql = """
            CREATE TABLE {output_summary_table} AS
            SELECT '{id_column}'::VARCHAR AS id_column,
                   {eps}::DOUBLE PRECISION AS eps,
                   '{metric}'::VARCHAR AS metric
            """.format(**locals())
        plpy.execute(sql)

        # Clean up all intermediate tables.
        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}, {3}".format(
            distance_table, core_points_table, core_edge_table,
            reachable_points_table))
def dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
                   expr_point, output_table, **kwargs):
    """Assign points in source_table to clusters from a fitted DBSCAN model.

    Each new point is matched against the core points stored in dbscan_table;
    per cluster, the minimum distance to any of that cluster's core points is
    computed, and only clusters within the trained eps are kept. Points
    farther than eps from every core point produce no output row.

    Args:
        @param schema_madlib: str, MADlib schema, used to qualify distance fns
        @param dbscan_table: str, output table of a previous dbscan() run
        @param source_table: str, table holding the new points to predict
        @param id_column: str, point-id column in source_table
        @param expr_point: str, column name or expression for the point array
        @param output_table: str, result table to create with columns
            ({id_column}, cluster_id, distance)

    Side effects: creates output_table.
    """
    with MinWarning("warning"):
        _validate_dbscan_predict(schema_madlib, dbscan_table, source_table,
                                 id_column, expr_point, output_table)

        # Recover the training-time parameters recorded by dbscan().
        dbscan_summary_table = add_postfix(dbscan_table, '_summary')
        summary = plpy.execute("SELECT * FROM {0}".format(dbscan_summary_table))[0]
        eps = summary['eps']
        metric = summary['metric']

        # For every (new point, cluster) pair take the distance to the closest
        # core point of that cluster, then keep only pairs within eps.
        sql = """
            CREATE TABLE {output_table} AS
            SELECT __q1__.{id_column}, cluster_id, distance
            FROM (
                SELECT __t2__.{id_column}, cluster_id,
                       min({schema_madlib}.{metric}(__t1__.__points__,
                                                    __t2__.{expr_point})) as distance
                FROM {dbscan_table} AS __t1__, {source_table} AS __t2__
                WHERE is_core_point = TRUE
                GROUP BY __t2__.{id_column}, cluster_id
            ) __q1__
            WHERE distance <= {eps}
            """.format(**locals())
        # Fix: the original bound this to an unused local (`result`) and also
        # read summary['id_column'] into an unused `db_id_column`.
        plpy.execute(sql)
def _validate_dbscan(schema_madlib, source_table, output_table, id_column,
                     expr_point, eps, min_samples, metric, algorithm):
    """Validate all dbscan() arguments; raises via plpy on the first failure."""
    # Table-level checks: input must exist, outputs must not.
    input_tbl_valid(source_table, 'dbscan')
    output_tbl_valid(output_table, 'dbscan')
    output_tbl_valid(add_postfix(output_table, '_summary'), 'dbscan')
    cols_in_tbl_valid(source_table, [id_column], 'dbscan')

    # The point expression must be valid SQL over source_table and must
    # evaluate to a numeric array.
    _assert(is_var_valid(source_table, expr_point),
            "dbscan error: {0} is an invalid column name or "
            "expression for expr_point param".format(expr_point))
    _assert(is_valid_psql_type(get_expr_type(expr_point, source_table),
                               NUMERIC | ONLY_ARRAY),
            "dbscan Error: Feature column or expression '{0}' in train table is not"
            " a numeric array.".format(expr_point))

    # Scalar parameter checks.
    _assert(eps > 0, "dbscan Error: eps has to be a positive number")
    _assert(min_samples > 0,
            "dbscan Error: min_samples has to be a positive number")
    valid_metrics = ('dist_norm1', 'dist_norm2', 'squared_dist_norm2',
                     'dist_angle', 'dist_tanimoto')
    _assert(metric in valid_metrics,
            "dbscan Error: metric has to be one of the madlib defined distance functions")
def _validate_dbscan_predict(schema_madlib, dbscan_table, source_table,
                             id_column, expr_point, output_table):
    """Validate dbscan_predict() arguments; raises via plpy on failure."""
    # Inputs (including the model's summary table) must exist; output must not.
    input_tbl_valid(source_table, 'dbscan')
    input_tbl_valid(dbscan_table, 'dbscan')
    input_tbl_valid(add_postfix(dbscan_table, '_summary'), 'dbscan')
    output_tbl_valid(output_table, 'dbscan')
    cols_in_tbl_valid(source_table, [id_column], 'dbscan')

    # The point expression must be valid over source_table and yield a
    # numeric array.
    _assert(is_var_valid(source_table, expr_point),
            "dbscan error: {0} is an invalid column name or "
            "expression for expr_point param".format(expr_point))
    _assert(is_valid_psql_type(get_expr_type(expr_point, source_table),
                               NUMERIC | ONLY_ARRAY),
            "dbscan Error: Feature column or expression '{0}' in train table is not"
            " a numeric array.".format(expr_point))
def dbscan_help(schema_madlib, message=None, **kwargs):
    """Return help/usage text for dbscan().

    Args:
        @param schema_madlib: str, MADlib schema name substituted into the text
        @param message: str or None; "usage"/"help"/"?" selects the usage text,
            anything else (including None) selects the summary text
        @param kwargs: ignored

    Returns:
        str, the formatted help message.
    """
    wants_usage = message is not None and message.lower() in ("usage", "help", "?")
    if wants_usage:
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.dbscan(
    source_table,       -- Name of the training data table
    output_table,       -- Name of the output table
    id_column,          -- Name of id column in source_table
    expr_point,         -- Column name or expression for data points
    eps,                -- The minimum radius of a cluster
    min_samples,        -- The minimum size of a cluster
    metric,             -- The name of the function to use to calculate the
                        -- distance
    algorithm           -- The algorithm to use for dbscan.
);
-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output of the dbscan function is a table with the following columns:
id_column           The ids of test data point
cluster_id          The id of the points associated cluster
is_core_point       Boolean column that indicates if the point is core or not
points              The column or expression for the data point
"""
    else:
        help_string = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
DBSCAN is a density-based clustering algorithm. Given a set of points in
some space, it groups together points that are closely packed together
(points with many nearby neighbors), marking as outliers points that lie
alone in low-density regions (whose nearest neighbors are too far away).
--
For an overview on usage, run:
SELECT {schema_madlib}.dbscan('usage');
SELECT {schema_madlib}.dbscan_predict('usage');
"""
    return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------
def dbscan_predict_help(schema_madlib, message=None, **kwargs):
    """Return help/usage text for dbscan_predict().

    Args:
        @param schema_madlib: str, MADlib schema name substituted into the text
        @param message: str or None; "usage"/"help"/"?" selects the usage text,
            anything else (including None) selects the summary text
        @param kwargs: ignored

    Returns:
        str, the formatted help message.
    """
    wants_usage = message is not None and message.lower() in ("usage", "help", "?")
    if wants_usage:
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.dbscan_predict(
    dbscan_table,       -- Name of the tdbscan output table
    new_point           -- Double precision array representing the point
                        -- for prediction
);
-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output of the dbscan_predict is an integer indicating the cluster_id
of given point
"""
    else:
        help_string = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
DBSCAN is a density-based clustering algorithm. Given a set of points in
some space, it groups together points that are closely packed together
(points with many nearby neighbors), marking as outliers points that lie
alone in low-density regions (whose nearest neighbors are too far away).
--
For an overview on usage, run:
SELECT {schema_madlib}.dbscan('usage');
SELECT {schema_madlib}.dbscan_predict('usage');
"""
    return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------