| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| |
| from utilities.control import MinWarning |
| from utilities.utilities import _assert |
| from utilities.utilities import unique_string |
| from utilities.utilities import add_postfix |
| from utilities.utilities import NUMERIC, ONLY_ARRAY |
| from utilities.utilities import is_valid_psql_type |
| from utilities.utilities import is_platform_pg |
| from utilities.validate_args import input_tbl_valid, output_tbl_valid |
| from utilities.validate_args import is_var_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import get_algorithm_name |
| from graph.wcc import wcc |
| |
# Canonical names of the algorithms accepted by dbscan(). The user-supplied
# value is resolved against this pair through get_algorithm_name().
BRUTE_FORCE = 'brute_force'
KD_TREE = 'kd_tree'
| |
def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm, **kwargs):
    """
    Cluster the rows of source_table with DBSCAN and store the result.

    Args:
        @param schema_madlib: Name of the schema holding the MADlib functions
        @param source_table: Name of the table with the points to cluster
        @param output_table: Name of the table created for the result. A
                             second table with the '_summary' postfix is
                             created next to it
        @param id_column: Name of the column identifying each point
        @param expr_point: Column name or expression for the coordinates of
                           a point (validated to be a numeric array)
        @param eps: Distance threshold. Two distinct points are neighbors
                    when their distance is strictly less than eps
        @param min_samples: Number of points (the point itself included)
                            required within eps of a point for it to be a
                            core point. Defaults to 5 when empty
        @param metric: Name of the MADlib distance function. Defaults to
                       'squared_dist_norm2' when empty
        @param algorithm: Algorithm name, resolved against BRUTE_FORCE and
                          KD_TREE. Defaults to BRUTE_FORCE when empty
        @param kwargs: Ignored

    Returns:
        None. output_table ends up with the id column, cluster_id,
        is_core_point and __points__; the summary table holds id_column,
        eps and metric.
    """
    with MinWarning("warning"):

        # Fall back to the defaults for every optional argument left empty.
        min_samples = min_samples or 5
        metric = metric or 'squared_dist_norm2'
        algorithm = algorithm or BRUTE_FORCE

        # NOTE(review): KD_TREE is accepted here, but everything below
        # computes all pairwise distances whatever name is resolved.
        algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
                                       [BRUTE_FORCE, KD_TREE], 'DBSCAN')

        _validate_dbscan(schema_madlib, source_table, output_table, id_column,
                         expr_point, eps, min_samples, metric, algorithm)

        # The distribution clauses are left empty when is_platform_pg() is
        # true; the check is done once and reused for all three clauses.
        on_postgres = is_platform_pg()
        dist_src_sql = '' if on_postgres else 'DISTRIBUTED BY (__src__)'
        dist_id_sql = '' if on_postgres else 'DISTRIBUTED BY ({0})'.format(id_column)
        dist_reach_sql = '' if on_postgres else 'DISTRIBUTED BY (__reachable_id__)'

        # The SQL templates below are filled in with .format(**locals()), so
        # every placeholder name has to match a local variable of this
        # function.

        # Calculate pairwise distances and keep only the neighbor pairs.
        # NOTE(review): expr_point is spliced in right after a table alias,
        # so only forms that are valid in that position will work.
        # NOTE(review): the comparison with eps is strict here, while
        # dbscan_predict compares with <= -- confirm which one is intended.
        distance_table = unique_string(desp='distance_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))

        sql = """
            CREATE TABLE {distance_table} AS
            SELECT __src__, __dest__ FROM (
                SELECT  __t1__.{id_column} AS __src__,
                        __t2__.{id_column} AS __dest__,
                        {schema_madlib}.{metric}(
                            __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__
                FROM {source_table} AS __t1__, {source_table} AS __t2__
                WHERE __t1__.{id_column} != __t2__.{id_column}) q1
            WHERE __dist__ < {eps}
            {dist_src_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Find core points
        # We use __count__ + 1 because we have to add the point itself to the count
        core_points_table = unique_string(desp='core_points_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_points_table))
        sql = """
            CREATE TABLE {core_points_table} AS
            SELECT * FROM (SELECT __src__ AS {id_column}, count(*) AS __count__
                           FROM {distance_table} GROUP BY __src__) q1
            WHERE __count__ + 1 >= {min_samples}
            {dist_id_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Find the connections between core points to form the clusters:
        # keep the neighbor pairs whose two ends are both core points.
        core_edge_table = unique_string(desp='core_edge_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_edge_table))
        sql = """
            CREATE TABLE {core_edge_table} AS
            SELECT __src__, __dest__
            FROM {distance_table} AS __t1__, (SELECT array_agg({id_column}) AS arr
                                              FROM {core_points_table}) __t2__
            WHERE __t1__.__src__ = ANY(arr) AND __t1__.__dest__ = ANY(arr)
            {dist_src_sql}
            """.format(**locals())
        plpy.execute(sql)

        # Run wcc on the core-point graph; it writes output_table with a
        # component_id column, which becomes the cluster id below.
        wcc(schema_madlib, core_points_table, id_column, core_edge_table,
            'src=__src__, dest=__dest__', output_table, None)
        plpy.execute("""
            ALTER TABLE {0}
            ADD COLUMN is_core_point BOOLEAN,
            ADD COLUMN __points__ DOUBLE PRECISION[]
            """.format(output_table))
        plpy.execute("""
            ALTER TABLE {0}
            RENAME COLUMN component_id TO cluster_id
            """.format(output_table))
        # Every row present at this stage comes from core_points_table.
        plpy.execute("""
            UPDATE {0}
            SET is_core_point = TRUE
            """.format(output_table))

        # Find reachable points: non-core points that are neighbors of at
        # least one core point, together with the list of those core points.
        reachable_points_table = unique_string(desp='reachable_points_table')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(reachable_points_table))
        sql = """
            CREATE TABLE {reachable_points_table} AS
                SELECT array_agg(__src__) AS __src_list__,
                       __dest__ AS __reachable_id__
                FROM {distance_table} AS __t1__,
                     (SELECT array_agg({id_column}) AS __arr__
                      FROM {core_points_table}) __t2__
                WHERE __src__ = ANY(__arr__) AND __dest__ != ALL(__arr__)
                GROUP BY __dest__
                {dist_reach_sql}
            """.format(**locals())
        plpy.execute(sql)

        # A reachable point takes the cluster of the first core point in its
        # aggregated list.
        sql = """
            INSERT INTO {output_table}
            SELECT  __reachable_id__ as {id_column},
                    cluster_id,
                    FALSE AS is_core_point,
                    NULL AS __points__
            FROM {reachable_points_table} AS __t1__ INNER JOIN
                 {output_table} AS __t2__
                 ON (__src_list__[1] = {id_column})
            """.format(**locals())
        plpy.execute(sql)

        # Add features of points to the output table to use them for prediction
        sql = """
            UPDATE {output_table} AS __t1__
            SET __points__ = {expr_point}
            FROM {source_table} AS __t2__
            WHERE __t1__.{id_column} = __t2__.{id_column}
            """.format(**locals())
        plpy.execute(sql)

        # Update the cluster ids to be consecutive, starting from 0
        sql = """
            UPDATE {output_table} AS __t1__
            SET cluster_id = new_id-1
            FROM (
                SELECT cluster_id, row_number() OVER(ORDER BY cluster_id) AS new_id
                FROM {output_table}
                GROUP BY cluster_id) __t2__
            WHERE __t1__.cluster_id = __t2__.cluster_id
            """.format(**locals())
        plpy.execute(sql)

        # Store the parameters that dbscan_predict reads back later.
        # NOTE(review): validation requires this table not to exist up front,
        # so the DROP presumably removes a summary left behind by wcc --
        # confirm.
        output_summary_table = add_postfix(output_table, '_summary')
        plpy.execute("DROP TABLE IF EXISTS {0}".format(output_summary_table))

        sql = """
            CREATE TABLE {output_summary_table} AS
            SELECT  '{id_column}'::VARCHAR AS id_column,
                    {eps}::DOUBLE PRECISION AS eps,
                    '{metric}'::VARCHAR AS metric
            """.format(**locals())
        plpy.execute(sql)

        # Remove the intermediate tables
        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}, {3}".format(
            distance_table, core_points_table, core_edge_table,
            reachable_points_table))
| |
| |
def dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
                   expr_point, output_table, **kwargs):
    """
    Assign the points of source_table to the clusters stored in dbscan_table.

    Args:
        @param schema_madlib: Name of the schema holding the MADlib functions
        @param dbscan_table: Name of the table holding the clustering result;
                             its companion table with the '_summary' postfix
                             supplies eps and the distance metric
        @param source_table: Name of the table with the points to assign
        @param id_column: Name of the column identifying each point in
                          source_table
        @param expr_point: Column name or expression for the coordinates of
                           a point (validated to be a numeric array)
        @param output_table: Name of the table created for the result
        @param kwargs: Ignored

    Returns:
        None. output_table gets one row (id, cluster_id, distance) for each
        point and cluster where the smallest distance between the point and
        a core point of that cluster is at most eps.
    """
    with MinWarning("warning"):

        _validate_dbscan_predict(schema_madlib, dbscan_table, source_table, id_column,
                                 expr_point, output_table)

        # eps and metric are read back from the summary table that
        # accompanies dbscan_table.
        dbscan_summary_table = add_postfix(dbscan_table, '_summary')
        summary = plpy.execute("SELECT * FROM {0}".format(dbscan_summary_table))[0]

        eps = summary['eps']
        metric = summary['metric']

        # The template is filled in with .format(**locals()), so every
        # placeholder name has to match a local variable of this function.
        # NOTE(review): a point within eps of several clusters gets one row
        # per cluster; points with no cluster within eps are left out.
        sql = """
            CREATE TABLE {output_table} AS
            SELECT __q1__.{id_column}, cluster_id, distance
            FROM (
                SELECT __t2__.{id_column}, cluster_id,
                       min({schema_madlib}.{metric}(__t1__.__points__,
                                                    __t2__.{expr_point})) as distance
                FROM {dbscan_table} AS __t1__, {source_table} AS __t2__
                WHERE is_core_point = TRUE
                GROUP BY __t2__.{id_column}, cluster_id
                ) __q1__
            WHERE distance <= {eps}
            """.format(**locals())
        plpy.execute(sql)
| |
def _validate_dbscan(schema_madlib, source_table, output_table, id_column,
                     expr_point, eps, min_samples, metric, algorithm):
    """
    Validate the arguments of dbscan before any table is created.

    Args:
        @param schema_madlib: Not used by the checks
        @param source_table: Has to pass input_tbl_valid
        @param output_table: This name and the name with the '_summary'
                             postfix both have to pass output_tbl_valid
        @param id_column: Has to be a column of source_table
        @param expr_point: Has to be a valid column or expression of
                           source_table with a numeric array type
        @param eps: Has to be a number greater than 0
        @param min_samples: Has to be a number greater than 0
        @param metric: Has to be one of the supported distance function names
        @param algorithm: Not used by the checks

    Returns:
        None. The helpers report the first failing check.
    """
    input_tbl_valid(source_table, 'dbscan')
    output_tbl_valid(output_table, 'dbscan')
    output_summary_table = add_postfix(output_table, '_summary')
    output_tbl_valid(output_summary_table, 'dbscan')

    cols_in_tbl_valid(source_table, [id_column], 'dbscan')

    _assert(is_var_valid(source_table, expr_point),
            "dbscan error: {0} is an invalid column name or "
            "expression for expr_point param".format(expr_point))

    point_col_type = get_expr_type(expr_point, source_table)
    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
            "dbscan Error: Feature column or expression '{0}' in train table is not"
            " a numeric array.".format(expr_point))

    # The explicit None checks make a missing value fail with the message
    # below instead of depending on how None compares to a number.
    _assert(eps is not None and eps > 0,
            "dbscan Error: eps has to be a positive number")

    _assert(min_samples is not None and min_samples > 0,
            "dbscan Error: min_samples has to be a positive number")

    fn_dist_list = ['dist_norm1', 'dist_norm2', 'squared_dist_norm2', 'dist_angle', 'dist_tanimoto']
    # The accepted names are appended so the caller can see the valid choices.
    _assert(metric in fn_dist_list,
            "dbscan Error: metric has to be one of the madlib defined distance "
            "functions ({0})".format(', '.join(fn_dist_list)))
| |
def _validate_dbscan_predict(schema_madlib, dbscan_table, source_table,
                             id_column, expr_point, output_table):
    """
    Validate the arguments of dbscan_predict before the output is created.

    Args:
        @param schema_madlib: Not used by the checks
        @param dbscan_table: This table and the one with the '_summary'
                             postfix both have to pass input_tbl_valid and
                             contain the columns read by dbscan_predict
        @param source_table: Has to pass input_tbl_valid
        @param id_column: Has to be a column of source_table
        @param expr_point: Has to be a valid column or expression of
                           source_table with a numeric array type
        @param output_table: Has to pass output_tbl_valid

    Returns:
        None. The helpers report the first failing check.
    """
    input_tbl_valid(source_table, 'dbscan')
    input_tbl_valid(dbscan_table, 'dbscan')
    dbscan_summary_table = add_postfix(dbscan_table, '_summary')
    input_tbl_valid(dbscan_summary_table, 'dbscan')
    output_tbl_valid(output_table, 'dbscan')

    # dbscan_predict reads these columns; checking them here reports a wrong
    # dbscan_table up front instead of through a failing query later.
    cols_in_tbl_valid(dbscan_table,
                      ['cluster_id', 'is_core_point', '__points__'], 'dbscan')
    cols_in_tbl_valid(dbscan_summary_table, ['eps', 'metric'], 'dbscan')

    cols_in_tbl_valid(source_table, [id_column], 'dbscan')

    _assert(is_var_valid(source_table, expr_point),
            "dbscan error: {0} is an invalid column name or "
            "expression for expr_point param".format(expr_point))

    # The table checked here is the prediction input, so the message names
    # the source table rather than a training table.
    point_col_type = get_expr_type(expr_point, source_table)
    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
            "dbscan Error: Feature column or expression '{0}' in source table is not"
            " a numeric array.".format(expr_point))
| |
def dbscan_help(schema_madlib, message=None, **kwargs):
    """
    Help function for dbscan

    Args:
        @param schema_madlib: Schema name substituted into the example calls
        @param message: string, Help message string. "usage", "help" and "?"
                        (in any letter case) select the usage text; any
                        other value, including None, selects the summary
        @param kwargs: Ignored

    Returns:
        String. Help/usage information
    """
    if message is not None and \
            message.lower() in ("usage", "help", "?"):
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.dbscan(
    source_table,       -- Name of the training data table
    output_table,       -- Name of the output table
    id_column,          -- Name of id column in source_table
    expr_point,         -- Column name or expression for data points
    eps,                -- Maximum distance between two points for them
                        -- to be considered neighbors
    min_samples,        -- Number of points (the point itself included)
                        -- that have to be within eps of a point for it
                        -- to be a core point (default: 5)
    metric,             -- The name of the function to use to calculate the
                        -- distance (default: squared_dist_norm2)
    algorithm           -- The algorithm to use for dbscan
                        -- (default: brute_force)
    );

-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output of the dbscan function is a table with the following columns:

id_column           The id of each point that belongs to a cluster
cluster_id          The id of the points associated cluster
is_core_point       Boolean column that indicates if the point is core or not
__points__          The column or expression for the data point

A summary table named <output_table>_summary is created as well. It stores
the id_column, eps and metric values used for the clustering.
"""
    else:
        help_string = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
DBSCAN is a density-based clustering algorithm. Given a set of points in
some space, it groups together points that are closely packed together
(points with many nearby neighbors), marking as outliers points that lie
alone in low-density regions (whose nearest neighbors are too far away).
--
For an overview on usage, run:
SELECT {schema_madlib}.dbscan('usage');
SELECT {schema_madlib}.dbscan_predict('usage');
"""
    return help_string.format(schema_madlib=schema_madlib)
| # ------------------------------------------------------------------------------ |
| |
def dbscan_predict_help(schema_madlib, message=None, **kwargs):
    """
    Help function for dbscan_predict

    Args:
        @param schema_madlib: Schema name substituted into the example calls
        @param message: string, Help message string. "usage", "help" and "?"
                        (in any letter case) select the usage text; any
                        other value, including None, selects the summary
        @param kwargs: Ignored

    Returns:
        String. Help/usage information
    """
    if message is not None and \
            message.lower() in ("usage", "help", "?"):
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.dbscan_predict(
    dbscan_table,       -- Name of the dbscan output table
    source_table,       -- Name of the table with the points to assign
    id_column,          -- Name of id column in source_table
    expr_point,         -- Column name or expression for data points
    output_table        -- Name of the table to store the result
    );

-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output of the dbscan_predict function is a table with the following
columns:

id_column           The id of the point from source_table
cluster_id          The id of a cluster that has a core point within eps
                    of the point
distance            The smallest distance between the point and the core
                    points of that cluster
"""
    else:
        help_string = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
DBSCAN is a density-based clustering algorithm. Given a set of points in
some space, it groups together points that are closely packed together
(points with many nearby neighbors), marking as outliers points that lie
alone in low-density regions (whose nearest neighbors are too far away).
--
For an overview on usage, run:
SELECT {schema_madlib}.dbscan('usage');
SELECT {schema_madlib}.dbscan_predict('usage');
"""
    return help_string.format(schema_madlib=schema_madlib)
| # ------------------------------------------------------------------------------ |