| # coding=utf-8 |
| m4_changequote(`<!', `!>') |
| |
| """ |
| @file kmeans.py_in |
| |
| @brief k-Means: Driver functions |
| |
| @namespace kmeans |
| |
| @brief k-Means: Driver functions |
| """ |
| |
| import plpy |
| import re |
| |
| from utilities.control import IterationController2D |
| from utilities.control import MinWarning |
| from utilities.control_composite import IterationControllerComposite |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| from utilities.utilities import _assert |
| from utilities.utilities import unique_string |
| |
| HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>) |
| # ---------------------------------------------------------------------- |
| |
| |
def kmeans_validate_src(schema_madlib, rel_source, **kwargs):
    """
    Validate the name of the relation holding the input points.

    The checks run in order and the first one that fails is reported through
    plpy.error: the name must not be None, blank or the literal 'null'; the
    table must exist; and the table must not be empty.

    @param schema_madlib Name of the MADlib schema (not used by the checks)
    @param rel_source Name of the relation containing input points
    @param kwargs Additional arguments, all of which are ignored
    """
    # Each predicate is wrapped in a lambda so that a later check is only
    # evaluated once every earlier check has passed.
    failure_checks = (
        (lambda: rel_source is None or
                 rel_source.strip().lower() in ('null', ''),
         "kmeans error: Invalid data table name!"),
        (lambda: not table_exists(rel_source),
         "kmeans error: Data table does not exist!"),
        (lambda: table_is_empty(rel_source),
         "kmeans error: Data table is empty!"),
    )
    for has_failed, message in failure_checks:
        if has_failed():
            plpy.error(message)
| # ---------------------------------------------------------------------- |
| |
| |
def kmeans_validate_expr(schema_madlib, rel_source, expr_point, **kwargs):
    """
    Validation function for the expr_point parameter

    expr_point accepts 2 formats:
        - A single column name of a numeric array
        - A numeric array expression

    @param schema_madlib Name of the MADlib schema (not used here)
    @param rel_source Name of the relation containing input points
    @param expr_point Column name or expression yielding the point coordinates
    @param kwargs Additional arguments, all of which are ignored
    @return True if expr_point is an expression rather than an existing
        column of rel_source, False if it names an existing column. A type
        that is not a numeric array is reported through plpy.error.
    """
    numeric_array_suffixes = (
        'smallint[]', 'integer[]', 'bigint[]', 'decimal[]',
        'numeric[]', 'real[]', 'double precision[]',
        'serial[]', 'bigserial[]', 'float8[]',
        'svec')

    resolved_type = get_expr_type(expr_point, rel_source).lower()

    # Both formats should resolve to a numeric array type
    if resolved_type.endswith(numeric_array_suffixes):
        # An array expression is not a column of the table, so the column
        # lookup fails for it; that failure is what marks an expression.
        return not columns_exist_in_table(rel_source, [expr_point])

    plpy.error("Kmeans error: {0} is not a valid "
               "column or array".format(expr_point))
| # ---------------------------------------------------------------------- |
| |
| |
def compute_kmeanspp_seeding(schema_madlib, rel_args, rel_state, rel_source,
                             expr_point, **kwargs):
    r"""
    Driver function for k-Means++ seeding

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all non-template
        arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in \c rel_state
    """
    # If expr_point is an array expression rather than a plain column, it is
    # exposed as column 'expr' of a temporary view so that the queries below
    # can reference it as _src.<column>.
    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                        rel_source,
                                                        expr_point)

    fn_dist_name = plpy.execute(
        "SELECT fn_dist_name FROM " + rel_args)[0]['fn_dist_name']
    iterationCtrl = IterationController2D(
        rel_args=rel_args,
        fn_dist_name=fn_dist_name,
        rel_state=rel_state,
        stateType="DOUBLE PRECISION[][]",
        truncAfterIteration=True,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        expr_point=expr_point)
    with iterationCtrl as it:
        fn_dist_str = plpy.execute(
            "SELECT fn_dist as t FROM " + rel_args)[0]['t']

        state_str = """(SELECT _state FROM {rel_state} WHERE _iteration = {iteration})"""
        closest_column_str = "{schema_madlib}._closest_column"

        # The % placeholders are substituted right here; the {} placeholders
        # are left in the text for the iteration controller to fill in.
        # The text does not change between iterations, so it is built once.
        # A small constant (1e-20) is added to the distance that serves as
        # the sampling weight -- presumably to keep weights strictly positive.
        update_query = """
            SELECT
                (%(state_str)s ||
                    {schema_madlib}.weighted_sample(
                        _src.{expr_point}::FLOAT8[],
                        1e-20 +
                        (
                            %(closest_column_str)s(
                                %(state_str)s
                                , _src.{expr_point}::FLOAT8[]
                                , '%(fn_dist_str)s'
                                , '{fn_dist_name}'
                            )
                        ).distance)
                )
            FROM {rel_source} AS _src
            WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
            """ % ({'state_str': state_str,
                    'fn_dist_str': fn_dist_str,
                    'closest_column_str': closest_column_str
                    })

        if it.test("_args.initial_centroids IS NULL"):
            # No centroids supplied: start from a single sampled point
            it.update("""
                SELECT
                    ARRAY[{schema_madlib}.weighted_sample(_src.{expr_point}::FLOAT8[], 1)] as b
                FROM {rel_source} AS _src
                WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
                """)
        else:
            it.update("""
                SELECT _args.initial_centroids FROM {rel_args} AS _args
                """)

        # Append one sampled centroid per pass until the state holds k rows
        while it.test("array_upper(_state._state, 1) < _args.k"):
            it.update(update_query)
    return iterationCtrl.iteration
| # ------------------------------------------------------------------------------ |
| |
| |
def compute_kmeans_random_seeding(schema_madlib, rel_args, rel_state,
                                  rel_source, expr_point, **kwargs):
    r"""
    Driver function for k-Means random seeding

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all non-template
        arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in \c rel_state
    """
    # If expr_point is an array expression rather than a plain column, it is
    # exposed as column 'expr' of a temporary view.
    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                        rel_source,
                                                        expr_point)

    iterationCtrl = IterationController2D(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="DOUBLE PRECISION[][]",
        truncAfterIteration=True,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        expr_point=expr_point)
    with iterationCtrl as it:
        it.update("""
            SELECT _args.initial_centroids FROM {rel_args} AS _args
            """)
        state_str = "_state._state"
        state_join_str = """
            , {rel_state} AS _state
            WHERE
                _state._iteration = {iteration}
            GROUP BY
                _state._state
            """
        # Number of centroids still missing from the current state
        remaining_expr = "_args.k - coalesce(array_upper({0}, 1), 0)".format(
            state_str)

        # The sampling threshold oversamples relative to m and LIMIT then
        # trims the result to m rows -- presumably so that one pass usually
        # yields enough rows. The loop below repeats if it does not.
        sample_query = """
            SELECT
                %(state_str)s || {schema_madlib}.matrix_agg(_point::FLOAT8[])
            FROM (
                SELECT
                    {expr_point} AS _point
                FROM
                    {rel_source} AS _src
                WHERE
                    random() < CAST(
                        (%(m)s + 14 + sqrt(196 + 28 * %(m)s))
                        / (
                            SELECT count(*) FROM {rel_source}
                            WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
                        )
                        AS DOUBLE PRECISION
                    )
                    AND abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
                ORDER BY random()
                LIMIT %(m)s
            ) AS _src
            %(state_join_str)s
            """

        m = it.evaluate(remaining_expr)
        while m > 0:
            # the % formatting in below string is evaluated immediately, while
            # the {} formatting is evaluated in the update function
            it.update(sample_query % {'m': m,
                                      'state_str': state_str,
                                      'state_join_str': state_join_str})
            m = it.evaluate(remaining_expr)
    return iterationCtrl.iteration
| # ------------------------------------------------------------------------------ |
def _create_temp_view_for_expr(schema_madlib, rel_source, expr_point):
    """
    Create a temporary view to evaluate the expr_point.

    @param schema_madlib Name of the MADlib schema
    @param rel_source Name of the relation containing input points
    @param expr_point Column name or expression yielding the point coordinates
    @return Tuple (relation, column). When expr_point is an expression this is
        the name of a newly created temporary view together with 'expr', the
        view column that holds the evaluated expression. Otherwise the
        arguments are returned unchanged.
    """
    needs_view = kmeans_validate_expr(schema_madlib, rel_source, expr_point)
    if not needs_view:
        # A plain column can be queried directly from the source relation
        return rel_source, expr_point

    view_name = unique_string('km_view')
    create_view_sql = """ CREATE TEMP VIEW {view_name} AS
        SELECT {expr_point} AS expr FROM {rel_source}
        """.format(view_name=view_name,
                   expr_point=expr_point,
                   rel_source=rel_source)
    plpy.execute(create_view_sql)

    return view_name, 'expr'
| |
def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
                   expr_point, agg_centroid, **kwargs):
    r"""
    Driver function for Lloyd's k-means local-search heuristic

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all non-template
        arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param agg_centroid Name of the aggregate applied to the points assigned
        to a centroid in order to compute the updated centroid
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in \c rel_state
    """

    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                        rel_source,
                                                        expr_point)

    fn_dist_name = plpy.execute("SELECT fn_dist_name FROM " +
                                rel_args)[0]['fn_dist_name']
    iterationCtrl = IterationControllerComposite(
        rel_args=rel_args,
        rel_state=rel_state,
        fn_dist_name=fn_dist_name,
        stateType="{schema_madlib}.kmeans_state",
        truncAfterIteration=False,
        schema_madlib=schema_madlib,
        rel_source=rel_source,
        expr_point=expr_point,
        agg_centroid=agg_centroid)
    with iterationCtrl as it:
        # Throughout this block the % placeholders are substituted right
        # away, whereas the {} placeholders stay in the text and are filled
        # in by the iteration controller.

        # Create the initial inter-iteration state of type kmeans_state
        centroid_str = """SELECT (_state).centroids
            FROM {rel_state} as rel_state
            WHERE _iteration = {iteration}"""
        old_centroid_str = """SELECT (_state).old_centroid_ids
            FROM {rel_state} as rel_state
            WHERE _iteration = {iteration}"""

        # note the 'comma' is included in the string below
        fn_dist_name_str = ", '{fn_dist_name}'"
        fn_dist_str = plpy.execute("SELECT fn_dist FROM " + rel_args
                                   )[0]["fn_dist"]
        closest_column_str = "{schema_madlib}._closest_column"

        # subquery to get closest centroid to given point from current
        # iteration assignments
        curr_assignment = closest_column_str + """(
            (%s)
            , _src.{expr_point}::FLOAT8[]
            , '%s'
            %s
            )
            """ % (centroid_str, fn_dist_str, fn_dist_name_str)

        prev_centroid_str = """SELECT (_state).centroids
            FROM {rel_state} as rel_state
            WHERE _iteration = {iteration} - 1
            """
        # subquery to get closest centroid to point from previous iteration
        # assignments
        prev_assignment_id = """(%s((%s)
            , _src.{expr_point}::FLOAT8[]
            , '%s'
            %s
            )
            ).column_id
            """ % (closest_column_str, prev_centroid_str,
                   fn_dist_str,
                   fn_dist_name_str)

        it.update("""
            SELECT
                CAST((_args.initial_centroids, NULL, NULL, 'Inf', 1.0) AS
                    {schema_madlib}.kmeans_state)
            FROM {rel_args} AS _args
            """)
        state_str = "_state._state"

        # None of the texts below change between iterations, so they are
        # built once before the loop.
        continue_condition = """{iteration} < _args.max_num_iterations AND
            (%s).frac_reassigned > _args.min_frac_reassigned
            """ % (state_str)

        # One Lloyd step: assign every point to its closest centroid,
        # recompute the centroids with agg_centroid, and record the objective
        # and the fraction of points whose assignment changed.
        lloyd_step_query = """
            SELECT
                CAST((
                    {schema_madlib}.matrix_agg(
                        _centroid::FLOAT8[]
                        ORDER BY _new_centroid_id),
                    array_agg(_new_centroid_id ORDER BY _new_centroid_id),
                    array_agg(_objective_fn ORDER BY _new_centroid_id),
                    sum(_objective_fn),
                    CAST(sum(_num_reassigned) AS DOUBLE PRECISION)
                        / sum(_num_points)
                ) AS {schema_madlib}.kmeans_state)
            FROM (
                SELECT
                    (_new_centroid).column_id AS _new_centroid_id,
                    sum((_new_centroid).distance) AS _objective_fn,
                    count(*) AS _num_points,
                    sum(
                        CAST(
                            coalesce(
                                (CAST(
                                    (%(old_centroid)s) AS INTEGER[]
                                ))[(_new_centroid).column_id + 1] != _old_centroid_id,
                                TRUE
                            )
                            AS INTEGER
                        )
                    ) AS _num_reassigned,
                    {agg_centroid}(_point::FLOAT8[]) AS _centroid
                FROM (
                    SELECT
                        -- PostgreSQL/Greenplum tuning:
                        -- VOLATILE function as optimization fence
                        {schema_madlib}.noop(),
                        _src.{expr_point} AS _point,
                        %(curr)s AS _new_centroid,
                        %(prev_id)s AS _old_centroid_id
                    FROM {rel_source} AS _src
                    WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
                        AND NOT {schema_madlib}.array_contains_null(_src.{expr_point}::FLOAT8[])
                ) AS _points_with_assignments
                GROUP BY (_new_centroid).column_id
            ) AS _new_centroids
            """ % ({'old_centroid': old_centroid_str,
                    'curr': curr_assignment,
                    'prev_id': prev_assignment_id})

        too_few_centroids = ("array_upper((%s).centroids, 1) < _args.k"
                             % (state_str))

        # Top the centroid set back up to k via kmeanspp_seeding. The new
        # state carries frac_reassigned = 1.0 and objective 'Inf'.
        reseed_query = """
            SELECT
                CAST((
                    (%(centroid)s) ||
                    {schema_madlib}.kmeanspp_seeding(
                        '{rel_source}',
                        '{expr_point}',
                        ( SELECT CAST(k - array_upper(
                                    (%(centroid)s), 1)
                                AS INTEGER)
                          FROM {rel_args}
                        ),
                        (SELECT fn_dist_name FROM {rel_args})
                    ),
                    (%(old_centroid)s),
                    NULL,
                    'Inf',
                    1.0
                ) AS {schema_madlib}.kmeans_state)
            """ % ({'centroid': centroid_str,
                    'old_centroid': old_centroid_str})

        while it.test(continue_condition):
            it.update(lloyd_step_query)

            # The step above only emits centroids that still have points
            # assigned, so the state may hold fewer than k of them.
            if it.test(too_few_centroids):
                it.update(reseed_query)
    return iterationCtrl.iteration
| |
def simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
                             expr_point, centroids, fn_dist, **kwargs):
    """
    Calculate the simple silhouette score for every data point.

    Creates output_table with one row per input row, holding the point id,
    the closest centroid, the second closest centroid and the score
    (d2 - d1) / d2, where d1 and d2 are the distances to those two centroids.
    The score is 0 when d2 is 0.

    @param schema_madlib Name of the MADlib schema
    @param rel_source Name of the relation containing input points
    @param output_table Name of the table to create; validated with
        output_tbl_valid
    @param pid Name of the column identifying a point
    @param expr_point Column name or expression yielding the point coordinates
    @param centroids Centroids as a list of lists; at least two centroids are
        required because the two closest ones are looked up for every point
    @param fn_dist Name of the distance function
    @param kwargs Additional arguments, all of which are ignored
    """

    with MinWarning("error"):
        kmeans_validate_src(schema_madlib, rel_source)
        output_tbl_valid(output_table, 'kmeans')

        # The length is checked before the first row is inspected so that an
        # empty list produces the error message below rather than an
        # IndexError.
        has_valid_shape = (isinstance(centroids, list) and
                           len(centroids) > 1 and
                           isinstance(centroids[0], list))
        _assert(has_valid_shape,
                'kmeans: Invalid centroids shape. Centroids have to be a 2D numeric array.')

        rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                            rel_source,
                                                            expr_point)

        # {centroids} is rendered with the list's string form, which directly
        # follows the SQL keyword 'array' to form an array constructor.
        plpy.execute("""
            CREATE TABLE {output_table} AS
            SELECT {pid}, centroids[1] AS centroid_id,
                   centroids[2] AS neighbor_centroid_id,
                   (CASE
                        WHEN distances[2] = 0 THEN 0
                        ELSE (distances[2] - distances[1]) / distances[2]
                    END) AS silh
            FROM
                (SELECT {pid},
                        (cc_out).column_ids::integer[] AS centroids,
                        (cc_out).distances::double precision[] AS distances
                 FROM (
                    SELECT {pid},
                           {schema_madlib}._closest_columns(
                                array{centroids},
                                {expr_point},
                                2,
                                '{fn_dist}'::REGPROC, '{fn_dist}') AS cc_out
                    FROM {rel_source})q1
                )q2
            """.format(output_table=output_table,
                       pid=pid,
                       schema_madlib=schema_madlib,
                       centroids=centroids,
                       expr_point=expr_point,
                       fn_dist=fn_dist,
                       rel_source=rel_source))
| |
def simple_silhouette_points_dbl_wrapper(schema_madlib, rel_source, output_table, pid,
                                         expr_point, centroids, fn_dist, **kwargs):
    """
    Entry point for simple_silhouette_points when the centroids are supplied
    directly as an array value.

    The named arguments are forwarded unchanged to simple_silhouette_points;
    kwargs is accepted but not forwarded.
    """
    forwarded_args = (schema_madlib, rel_source, output_table, pid,
                      expr_point, centroids, fn_dist)
    simple_silhouette_points(*forwarded_args)
| |
| |
def simple_silhouette_points_str_wrapper(schema_madlib, rel_source, output_table, pid,
                                         expr_point, centroids_table, centroids_col, fn_dist, **kwargs):
    """
    Entry point for simple_silhouette_points when the centroids are stored in
    a table.

    Reads centroids_col from the first row returned for centroids_table and
    forwards it, together with the remaining arguments, to
    simple_silhouette_points.

    @param schema_madlib Name of the MADlib schema
    @param rel_source Name of the relation containing input points
    @param output_table Name of the table to create
    @param pid Name of the column identifying a point
    @param expr_point Column name or expression yielding the point coordinates
    @param centroids_table Name of the table holding the centroids
    @param centroids_col Name of the column of centroids_table holding the
        centroids
    @param fn_dist Name of the distance function
    @param kwargs Additional arguments, passed on to simple_silhouette_points
    """

    input_tbl_valid(centroids_table, 'kmeans')
    # columns_exist_in_table takes a list of column names and reports its
    # verdict through the return value, so the result has to be checked.
    # NOTE(review): this rejects a centroids_col that is an expression rather
    # than a plain column name -- confirm that no caller relies on that.
    _assert(columns_exist_in_table(centroids_table, [centroids_col]),
            "kmeans error: Column {0} does not exist in table {1}!".format(
                centroids_col, centroids_table))

    centroids = plpy.execute("""
        SELECT {centroids_col} AS centroids FROM {centroids_table}
        """.format(centroids_col=centroids_col,
                   centroids_table=centroids_table))[0]['centroids']

    simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
                             expr_point, centroids, fn_dist, **kwargs)
| |
| m4_changequote(<!`!>, <!'!>) |