blob: 30e4005f83fa62a0741cbd41a5632dce60e91565 [file] [log] [blame]
# coding=utf-8
m4_changequote(`<!', `!>')
"""
@file kmeans.py_in
@brief k-Means: Driver functions
@namespace kmeans
@brief k-Means: Driver functions
"""
import plpy
import re
from utilities.control import IterationController2D
from utilities.control import MinWarning
from utilities.control_composite import IterationControllerComposite
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.utilities import _assert
from utilities.utilities import unique_string
# Build-time flag resolved by the m4 preprocessor: expands to True when the
# macro __HAS_FUNCTION_PROPERTIES__ is defined, and to False otherwise.
HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
# ----------------------------------------------------------------------
def kmeans_validate_src(schema_madlib, rel_source, **kwargs):
    """
    Validate the name of the source relation used by the k-means routines.

    Calls plpy.error when the name is missing (None, blank, or the literal
    text 'null'), when table_exists reports that the table does not exist,
    or when table_is_empty reports that it is empty.

    @param schema_madlib Name of the MADlib schema (not used by the checks)
    @param rel_source Name of the relation containing input points
    @param kwargs Additional arguments are accepted and ignored
    """
    name_is_missing = (rel_source is None or
                       rel_source.strip().lower() in ('null', ''))
    if name_is_missing:
        plpy.error("kmeans error: Invalid data table name!")

    table_found = table_exists(rel_source)
    if not table_found:
        plpy.error("kmeans error: Data table does not exist!")

    if table_is_empty(rel_source):
        plpy.error("kmeans error: Data table is empty!")
# ----------------------------------------------------------------------
def kmeans_validate_expr(schema_madlib, rel_source, expr_point, **kwargs):
    """
    Validate expr_point and tell the two accepted formats apart.

    expr_point may be given either as
        - the name of a single numeric-array column, or
        - an expression that evaluates to a numeric array.

    @param schema_madlib Name of the MADlib schema (not used here)
    @param rel_source Name of the relation containing input points
    @param expr_point Column name or expression with the point coordinates
    @param kwargs Additional arguments are accepted and ignored
    @return False when expr_point names an existing column of rel_source,
        True when it is a numeric array expression. plpy.error is called
        when the type of expr_point is not an accepted numeric array type.
    """
    numeric_array_suffixes = ('smallint[]', 'integer[]', 'bigint[]',
                              'decimal[]', 'numeric[]', 'real[]',
                              'double precision[]', 'serial[]',
                              'bigserial[]', 'float8[]', 'svec')
    resolved_type = get_expr_type(expr_point, rel_source).lower()

    # Whichever format was used, the type has to be a numeric array
    if resolved_type.endswith(numeric_array_suffixes):
        # A plain column passes the existence check (result: False), whereas
        # an array expression matches no column name (result: True).
        return not columns_exist_in_table(rel_source, [expr_point])

    plpy.error("Kmeans error: {0} is not a valid "
               "column or array".format(expr_point))
# ----------------------------------------------------------------------
def compute_kmeanspp_seeding(schema_madlib, rel_args, rel_state, rel_source,
                             expr_point, **kwargs):
    """
    Driver function for k-Means++ seeding

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in rel_state
    """
    # When expr_point is an array expression rather than a plain column, the
    # shared helper wraps it in a temporary view and hands back the view name
    # together with the column name 'expr'.
    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                        rel_source,
                                                        expr_point)

    fn_dist_name = plpy.execute("SELECT fn_dist_name FROM " + rel_args)[0]['fn_dist_name']

    iterationCtrl = IterationController2D(
        rel_args=rel_args,
        fn_dist_name=fn_dist_name,
        rel_state=rel_state,
        stateType="DOUBLE PRECISION[][]",
        truncAfterIteration=True,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        expr_point=expr_point)

    with iterationCtrl as it:
        fn_dist_str = plpy.execute("SELECT fn_dist as t FROM " + rel_args)[0]['t']
        state_str = """(SELECT _state FROM {rel_state} WHERE _iteration = {iteration})"""
        closest_column_str = "{schema_madlib}._closest_column"

        # The % placeholders are substituted here; the {} placeholders are
        # left in place for the iteration controller. The query appends one
        # sampled point to the current state, weighting every row by 1e-20
        # plus its distance to the closest centroid chosen so far. It does
        # not change between iterations, so it is built only once.
        update_query = """
SELECT
(%(state_str)s ||
{schema_madlib}.weighted_sample(
_src.{expr_point}::FLOAT8[],
1e-20 +
(
%(closest_column_str)s(
%(state_str)s
, _src.{expr_point}::FLOAT8[]
, '%(fn_dist_str)s'
, '{fn_dist_name}'
)
).distance)
)
FROM {rel_source} AS _src
WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
""" % ({'state_str': state_str,
        'fn_dist_str': fn_dist_str,
        'closest_column_str': closest_column_str
        })

        if it.test("_args.initial_centroids IS NULL"):
            # No centroids supplied: draw the first one with equal weights
            # from the rows whose element sum is neither NULL nor infinite.
            it.update("""
SELECT
ARRAY[{schema_madlib}.weighted_sample(_src.{expr_point}::FLOAT8[], 1)] as b
FROM {rel_source} AS _src
WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
""")
        else:
            # Start from the centroids provided by the caller
            it.update("""
SELECT _args.initial_centroids FROM {rel_args} AS _args
""")

        # Add one centroid per iteration until k of them are available
        while it.test("array_upper(_state._state, 1) < _args.k"):
            it.update(update_query)

    return iterationCtrl.iteration
# ------------------------------------------------------------------------------
def compute_kmeans_random_seeding(schema_madlib, rel_args, rel_state,
                                  rel_source, expr_point, **kwargs):
    """
    Driver function for k-Means random seeding

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in rel_state
    """
    # When expr_point is an array expression rather than a plain column, the
    # shared helper wraps it in a temporary view and hands back the view name
    # together with the column name 'expr'.
    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                        rel_source,
                                                        expr_point)

    iterationCtrl = IterationController2D(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="DOUBLE PRECISION[][]",
        truncAfterIteration=True,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        expr_point=expr_point)

    with iterationCtrl as it:
        # Start from the centroids supplied by the caller
        it.update("""
SELECT _args.initial_centroids FROM {rel_args} AS _args
""")

        state_str = "_state._state"
        state_join_str = """
, {rel_state} AS _state
WHERE
_state._iteration = {iteration}
GROUP BY
_state._state
"""
        # Number of centroids still missing: k minus the rows in the state
        missing_centroids_expr = (
            "_args.k - coalesce(array_upper({0}, 1), 0)".format(state_str))

        m = it.evaluate(missing_centroids_expr)
        while m > 0:
            # the % formatting in below string is evaluated immediately, while the
            # {} formatting is evaluated in the update function
            #
            # Every row with a finite element sum is picked with probability
            # (m + 14 + sqrt(196 + 28 * m)) / <number of such rows>; at most
            # m of the picked rows are appended to the state.
            it.update("""
SELECT
%(state_str)s || {schema_madlib}.matrix_agg(_point::FLOAT8[])
FROM (
SELECT
{expr_point} AS _point
FROM
{rel_source} AS _src
WHERE
random() < CAST(
(%(m)s + 14 + sqrt(196 + 28 * %(m)s))
/ (
SELECT count(*) FROM {rel_source}
WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
)
AS DOUBLE PRECISION
)
AND abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
ORDER BY random()
LIMIT %(m)s
) AS _src
%(state_join_str)s
"""%{'m': m, 'state_str': state_str, 'state_join_str': state_join_str}
            )
            m = it.evaluate(missing_centroids_expr)

    return iterationCtrl.iteration
# ------------------------------------------------------------------------------
def _create_temp_view_for_expr(schema_madlib, rel_source, expr_point):
    """
    Wrap an array expression in a temporary view so that it can be used
    like an ordinary column.

    @param schema_madlib Name of the MADlib schema
    @param rel_source Name of the relation containing input points
    @param expr_point Column name or array expression with the coordinates
    @return The pair (relation, column) to use from here on: the inputs
        unchanged when kmeans_validate_expr returns a false value, otherwise
        the name of the newly created view and the column name 'expr'.
    """
    needs_view = kmeans_validate_expr(schema_madlib, rel_source, expr_point)
    if not needs_view:
        return rel_source, expr_point

    view_name = unique_string('km_view')
    plpy.execute(""" CREATE TEMP VIEW {view_name} AS
SELECT {expr_point} AS expr FROM {rel_source}
""".format(view_name=view_name,
           expr_point=expr_point,
           rel_source=rel_source))
    return view_name, 'expr'
def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
                   expr_point, agg_centroid, **kwargs):
    """
    Driver function for Lloyd's k-means local-search heuristic

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param expr_point Expression containing the point coordinates
    @param agg_centroid Name of the aggregate applied to the points of each
        cluster to obtain its new centroid
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in rel_state
    """
    rel_source, expr_point = _create_temp_view_for_expr(
        schema_madlib, rel_source, expr_point)

    dist_name_rows = plpy.execute("SELECT fn_dist_name FROM " + rel_args)
    fn_dist_name = dist_name_rows[0]['fn_dist_name']

    controller = IterationControllerComposite(
        rel_args=rel_args,
        rel_state=rel_state,
        fn_dist_name=fn_dist_name,
        stateType="{schema_madlib}.kmeans_state",
        truncAfterIteration=False,
        schema_madlib=schema_madlib,
        rel_source=rel_source,
        expr_point=expr_point,
        agg_centroid=agg_centroid)

    with controller as it:
        dist_rows = plpy.execute("SELECT fn_dist FROM " + rel_args)
        fn_dist_str = dist_rows[0]["fn_dist"]

        # ---- SQL building blocks -------------------------------------
        # The %-substitutions below happen right away; every {name}
        # placeholder is left for the iteration controller to fill in.
        closest_fn = "{schema_madlib}._closest_column"
        # note the 'comma' is included in the string below
        dist_name_arg = ", '{fn_dist_name}'"
        state_ref = "_state._state"

        # Centroids and centroid ids stored for the current iteration
        current_centroids_sql = """SELECT (_state).centroids
FROM {rel_state} as rel_state
WHERE _iteration = {iteration}"""
        old_ids_sql = """SELECT (_state).old_centroid_ids
FROM {rel_state} as rel_state
WHERE _iteration = {iteration}"""
        # Centroids stored for the iteration before the current one
        previous_centroids_sql = """SELECT (_state).centroids
FROM {rel_state} as rel_state
WHERE _iteration = {iteration} - 1
"""

        # Closest centroid to a point, using the current centroids
        current_call_args = """(
(%s)
, _src.{expr_point}::FLOAT8[]
, '%s'
%s
)
""" % (current_centroids_sql, fn_dist_str, dist_name_arg)
        current_assignment_sql = closest_fn + current_call_args

        # Id of the closest centroid to a point, using the previous centroids
        previous_assignment_id_sql = """(%s((%s)
, _src.{expr_point}::FLOAT8[]
, '%s'
%s
)
).column_id
""" % (closest_fn, previous_centroids_sql,
       fn_dist_str,
       dist_name_arg)

        # Loop condition: iteration budget left and still enough reassignments
        keep_iterating_test = """{iteration} < _args.max_num_iterations AND
(%s).frac_reassigned > _args.min_frac_reassigned
""" % (state_ref)

        # One Lloyd step. The inner query assigns every finite, NULL-free
        # point to its closest centroid; the middle query aggregates per
        # centroid (new centroid, objective, counts); the outer query packs
        # everything into a kmeans_state value. A point counts as reassigned
        # when the old id recorded at position column_id + 1 differs from its
        # previous assignment, or when that comparison yields NULL.
        # NOTE(review): the "+ 1" suggests column_id is zero-based - confirm.
        lloyd_step_sql = """
SELECT
CAST((
{schema_madlib}.matrix_agg(
_centroid::FLOAT8[]
ORDER BY _new_centroid_id),
array_agg(_new_centroid_id ORDER BY _new_centroid_id),
array_agg(_objective_fn ORDER BY _new_centroid_id),
sum(_objective_fn),
CAST(sum(_num_reassigned) AS DOUBLE PRECISION)
/ sum(_num_points)
) AS {schema_madlib}.kmeans_state)
FROM (
SELECT
(_new_centroid).column_id AS _new_centroid_id,
sum((_new_centroid).distance) AS _objective_fn,
count(*) AS _num_points,
sum(
CAST(
coalesce(
(CAST(
(%(old_centroid)s) AS INTEGER[]
))[(_new_centroid).column_id + 1] != _old_centroid_id,
TRUE
)
AS INTEGER
)
) AS _num_reassigned,
{agg_centroid}(_point::FLOAT8[]) AS _centroid
FROM (
SELECT
-- PostgreSQL/Greenplum tuning:
-- VOLATILE function as optimization fence
{schema_madlib}.noop(),
_src.{expr_point} AS _point,
%(curr)s AS _new_centroid,
%(prev_id)s AS _old_centroid_id
FROM {rel_source} AS _src
WHERE abs(coalesce({schema_madlib}.svec_elsum({expr_point}), 'Infinity'::FLOAT8)) < 'Infinity'::FLOAT8
AND NOT {schema_madlib}.array_contains_null(_src.{expr_point}::FLOAT8[])
) AS _points_with_assignments
GROUP BY (_new_centroid).column_id
) AS _new_centroids
""" % ({'old_centroid': old_ids_sql,
        'curr': current_assignment_sql,
        'prev_id': previous_assignment_id_sql})

        # Fewer than k centroids came out of the step
        too_few_centroids_test = (
            "array_upper((%s).centroids, 1) < _args.k"%(state_ref))

        # Top the centroids up again through kmeanspp_seeding and reset the
        # objective to 'Inf' and the reassigned fraction to 1.0
        reseed_sql = """
SELECT
CAST((
(%(centroid)s) ||
{schema_madlib}.kmeanspp_seeding(
'{rel_source}',
'{expr_point}',
( SELECT CAST(k - array_upper(
(%(centroid)s), 1)
AS INTEGER)
FROM {rel_args}
),
(SELECT fn_dist_name FROM {rel_args})
),
(%(old_centroid)s),
NULL,
'Inf',
1.0
) AS {schema_madlib}.kmeans_state)
""" % ({'centroid': current_centroids_sql,
        'old_centroid': old_ids_sql})

        # ---- Iteration -----------------------------------------------
        # Create the initial inter-iteration state of type kmeans_state
        it.update("""
SELECT
CAST((_args.initial_centroids, NULL, NULL, 'Inf', 1.0) AS
{schema_madlib}.kmeans_state)
FROM {rel_args} AS _args
""")

        while it.test(keep_iterating_test):
            it.update(lloyd_step_sql)
            if it.test(too_few_centroids_test):
                it.update(reseed_sql)

    return controller.iteration
def simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
                             expr_point, centroids, fn_dist, **kwargs):
    """
    Calculate the simple silhouette score for every data point.

    Creates output_table with one row per input row, holding pid, the id of
    the first centroid returned by _closest_columns (centroid_id), the id of
    the second one (neighbor_centroid_id) and the score
    silh = (d2 - d1) / d2, where d1 and d2 are the first and second returned
    distances; silh is 0 when d2 is 0.
    NOTE(review): presumably _closest_columns returns the centroids ordered
    by increasing distance - confirm.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_source Name of the relation containing input points
    @param output_table Name of the table to create
    @param pid Name of the column identifying each point
    @param expr_point Expression containing the point coordinates
    @param centroids Centroids as a list of lists (2D numeric array) with at
        least two rows
    @param fn_dist Name of the distance function
    @param kwargs Additional arguments are accepted and ignored
    """
    with MinWarning("error"):
        kmeans_validate_src(schema_madlib, rel_source)
        output_tbl_valid(output_table, 'kmeans')

        # Check the container type and its length BEFORE looking at
        # centroids[0]; otherwise an empty list raises IndexError instead of
        # producing the error message below.
        has_valid_shape = (isinstance(centroids, list) and
                           len(centroids) > 1 and
                           isinstance(centroids[0], list))
        _assert(has_valid_shape,
                'kmeans: Invalid centroids shape. Centroids have to be a 2D numeric array.')

        rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
                                                            rel_source,
                                                            expr_point)

        # The Python list is rendered through its repr, i.e. as
        # array[[...], [...]] in the generated SQL.
        plpy.execute("""
CREATE TABLE {output_table} AS
SELECT {pid}, centroids[1] AS centroid_id,
centroids[2] AS neighbor_centroid_id,
(CASE
WHEN distances[2] = 0 THEN 0
ELSE (distances[2] - distances[1]) / distances[2]
END) AS silh
FROM
(SELECT {pid},
(cc_out).column_ids::integer[] AS centroids,
(cc_out).distances::double precision[] AS distances
FROM (
SELECT {pid},
{schema_madlib}._closest_columns(
array{centroids},
{expr_point},
2,
'{fn_dist}'::REGPROC, '{fn_dist}') AS cc_out
FROM {rel_source})q1
)q2
""".format(schema_madlib=schema_madlib,
           rel_source=rel_source,
           output_table=output_table,
           pid=pid,
           expr_point=expr_point,
           centroids=centroids,
           fn_dist=fn_dist))
def simple_silhouette_points_dbl_wrapper(schema_madlib, rel_source, output_table, pid,
                                         expr_point, centroids, fn_dist, **kwargs):
    """
    Entry point for callers that pass the centroids directly as an array.

    Forwards the named arguments to simple_silhouette_points; anything in
    kwargs is dropped.
    """
    simple_silhouette_points(schema_madlib=schema_madlib,
                             rel_source=rel_source,
                             output_table=output_table,
                             pid=pid,
                             expr_point=expr_point,
                             centroids=centroids,
                             fn_dist=fn_dist)
def simple_silhouette_points_str_wrapper(schema_madlib, rel_source, output_table, pid,
                                         expr_point, centroids_table, centroids_col, fn_dist, **kwargs):
    """
    Entry point for callers that name a table and column holding the
    centroids instead of passing the centroids themselves.

    Reads centroids_col from the first row returned for centroids_table and
    forwards it, together with the other arguments and kwargs, to
    simple_silhouette_points.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_source Name of the relation containing input points
    @param output_table Name of the table to create
    @param pid Name of the column identifying each point
    @param expr_point Expression containing the point coordinates
    @param centroids_table Name of the table that stores the centroids
    @param centroids_col Name of the column of centroids_table that stores
        the centroids
    @param fn_dist Name of the distance function
    @param kwargs Additional arguments, passed on unchanged
    """
    input_tbl_valid(centroids_table, 'kmeans')

    # columns_exist_in_table takes a LIST of column names and only reports
    # its verdict through the return value, so the result has to be checked.
    _assert(columns_exist_in_table(centroids_table, [centroids_col]),
            "kmeans error: Column '{0}' does not exist in table '{1}'!".format(
                centroids_col, centroids_table))

    centroids = plpy.execute("""
SELECT {centroids_col} AS centroids FROM {centroids_table}
""".format(centroids_col=centroids_col,
           centroids_table=centroids_table))[0]['centroids']

    simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
                             expr_point, centroids, fn_dist, **kwargs)
m4_changequote(<!`!>, <!'!>)