| #!/usr/bin/env python |
| |
| """@namespace kmeans |
| |
| |
| @file kmeans.py_in |
| @brief k-Means Clustering with k-means++ centroid seeding |
| """ |
| |
| import datetime |
| import plpy |
| from math import floor, log, pow |
| |
| # ---------------------------------------- |
| # K-means global variables |
| # ---------------------------------------- |
| global output_points, output_centroids |
| |
| # ---------------------------------------- |
| # Logging |
| # ---------------------------------------- |
| def info( msg): |
| #plpy.info( datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ' : ' + msg); |
| plpy.info( msg); |
| |
| # ---------------------------------------- |
| # Quotes a string to be used as an identifier |
| # ---------------------------------------- |
| def quote_ident(val): |
| return '"' + val.replace('"', '""') + '"'; |
| |
| # ---------------------------------------- |
| # Quotes a string to be used as a literal |
| # ---------------------------------------- |
| def quote_literal(val): |
| return "'" + val.replace("'", "''") + "'"; |
| |
| # ---------------------------------------- |
| # Function to initialize K centroids |
| # ---------------------------------------- |
| def __kmeans_init( input_table, k): |
| """ |
| Creates the initial set of centroids. |
| |
| @param input_table Name of relation containing the input data points |
| @param k Number of centroids to generate |
| """ |
| |
| # Record the time |
| start = datetime.datetime.now(); |
| |
| # Find how many points we need to look through to be 99.9% sure that |
| # we have seen a point from each cluster |
| numCentroids = floor( - log( 1 - pow( 0.999, 1/float(k))) * k); |
| |
| # Create output tables - Points |
| plpy.execute( 'DROP TABLE IF EXISTS ' + output_points); |
| sql = ''' |
| CREATE TABLE ''' + output_points + ''' ( |
| pid BIGINT, |
| position MADLIB_SCHEMA.SVEC, |
| cid BIGINT |
| ) |
| '''; |
| plpy.execute( sql); |
| |
| # Create output tables - Centroids |
| plpy.execute( 'DROP TABLE IF EXISTS ' + output_centroids); |
| sql = ''' |
| CREATE TABLE ''' + output_centroids + ''' ( |
| cid INT, |
| position MADLIB_SCHEMA.SVEC |
| ) |
| '''; |
| plpy.execute( sql); |
| |
| # Find the 1st centroid |
| info( 'Seeding ' + str(k) + ' centroids...'); |
| sql = ''' |
| INSERT INTO ''' + output_centroids + ''' (cid, position) |
| SELECT 1, position::MADLIB_SCHEMA.SVEC |
| FROM ''' + input_table + ''' |
| ORDER BY random() DESC |
| LIMIT 1 |
| '''; |
| plpy.execute( sql); |
| |
| # Find the remaining centroids |
| i = 1 |
| while (i < k): |
| i = i + 1; |
| sql = ''' |
| INSERT INTO ''' + output_centroids + ''' (cid, position) |
| SELECT ''' + str(i) + ''', pt.position |
| FROM |
| ''' + input_table + ''' pt, |
| ( |
| -- Q1: Find a random point |
| -- (that is furthest away from current Centroids) |
| -- |
| SELECT p.pid, min( MADLIB_SCHEMA.l2norm (p.position - c.position)) * (random()^(0.2)) AS distance |
| FROM |
| (SELECT position, pid FROM ''' + input_table + ''' ORDER BY random() LIMIT ''' + str(numCentroids) + ''') AS p -- K random points |
| CROSS JOIN |
| (SELECT position FROM ''' + output_centroids + ''') AS c -- current centroids |
| GROUP BY p.pid ORDER BY distance DESC LIMIT 1 |
| ) AS pc |
| -- |
| -- Q1: end |
| -- |
| WHERE pc.pid = pt.pid |
| '''; |
| plpy.execute( sql); |
| # info( "Centroid " + str(i) + " seeded."); |
| |
| # Runtime evaluation |
| end = datetime.datetime.now(); |
| minutes, seconds = divmod( (end - start).seconds, 60) |
| microsec = (end - start).microseconds |
| |
| if ('caller' in globals()): |
| # If called from KMEANS_RUN |
| if ( caller == 'kmeans_run'): return i; |
| else: |
| # If called directly |
| return ''' |
| Finished seeding %d centroids for %s table/view. |
| Time elapsed: %d minutes %d.%d seconds. |
| ''' % (i, input_table, minutes, seconds, microsec) |
| |
| |
| # ---------------------------------------- |
| # Function to run the k-means algorithm |
| # ---------------------------------------- |
| def kmeans_run( input_table, k, goodness, run_id, output_schema): |
| """ |
| Executes the k-means clustering algorithm. |
| |
| @param input_table Name of relation containing the input data points |
| @param k Number of centroids to generate |
| @param goodness Goodness of fit test flag (allowed values: 0,1) |
| @param run_id Name/ID of the execution |
| @param output_schema Target schema for the output tables. |
| """ |
| |
| # Record the time |
| start = datetime.datetime.now(); |
| |
| # |
| # Adjustable Variables |
| # |
| sampling = 1; # if set to 1 core sampling is on otherwise algorithm will be executed on the full data set |
| sampling_size = 100; # make the sample size sufficient to have at least 'sampling_size' elements from each cluster with p = .999 |
| max_sample_size = 10000000; # maximum sample size |
| change_pct_limit = 0.001; # % of points to change assigment |
| max_iterations = 20; # Maximum number of allowed iterations |
| |
| # |
| # Non-Adjustable Variables |
| # |
| p_count = 0; # number of input points |
| c_count = 0; # number of initial centroids |
| c_count_final = 0; # number of final centroids |
| change_pct = [100.0]; # error array |
| sample_size = 0; # sample size - auto generated |
| expand = 0; # set to 1 if the k-means was executed on a sample set |
| done = 0; # loop control variable |
| gfit = ''; # goodness of fit description |
| |
| # |
| # Validate all parameters |
| # |
| info( 'Started kmeans_run() with parameters:'); |
| |
| # Validate parameter: k (positive integer) |
| if not(k>0): |
| plpy.error( "number of centroids must a positive integer (" + str(k) + ")\n"); |
| else: |
| info( ' * k = %s (number of centroids)' % k); |
| |
| # Validate parameter: input_table |
| if ( input_table.find('.') < 0): |
| plpy.error( "input table/view is missing a schema prefix (" + input_table + ")\n"); |
| else: |
| input_schemaname, input_tablename = input_table.split('.'); |
| # Does it exist? |
| rv = plpy.execute( "SELECT count(*) as cnt FROM information_schema.columns WHERE " |
| + "table_schema =" + quote_literal( input_schemaname) |
| + " AND table_name =" + quote_literal( input_tablename)); |
| if (rv[0]['cnt'] == 0): |
| plpy.error( "input table/view does not exists (" + input_table + ")\n"); |
| # Does it have PID column |
| rv = plpy.execute( "SELECT count(*) as cnt FROM information_schema.columns WHERE " |
| + "table_schema = " + quote_literal( input_schemaname) |
| + " AND table_name = " + quote_literal( input_tablename) |
| + " AND column_name = 'pid'"); |
| if (rv[0]['cnt'] == 0): |
| plpy.error( "input table/view does not have a PID column (" + input_table + ")\n"); |
| # Does it have POSITION column |
| rv = plpy.execute( "SELECT count(*) as cnt FROM information_schema.columns WHERE " |
| + "table_schema = " + quote_literal( input_schemaname) |
| + " AND table_name = " + quote_literal( input_tablename) |
| + " AND column_name = 'position'"); |
| if (rv[0]['cnt'] == 0): |
| plpy.error( "input table/view does not have a POSITION column (" + input_table + ")\n"); |
| # Input table is OK |
| info( ' * input_table = %s' % input_table); |
| input_table = quote_ident( input_schemaname) + '.' + quote_ident( input_tablename); |
| # But does it have more rows than requested number of centroids |
| rv = plpy.execute( "SELECT count(*) as cnt FROM " + input_table); |
| p_count = rv[0]['cnt']; |
| if (p_count == 0): |
| plpy.error( "input table is empty (" + input_table + ")\n"); |
| elif (p_count <= k): |
| plpy.error( "input table has less points (" + str(p_count) + ") " |
| + "than requested number of centroids (" + str(k) + ")\n"); |
| |
| # Validate parameter: goodness |
| if (goodness == 1): |
| info( ' * goodness = %s (GOF test on)' % str(goodness)); |
| elif (goodness == 0): |
| info( ' * goodness = %s (GOF test off)' % str(goodness)); |
| else: |
| plpy.error( 'incorrect value for goodness: ' + str(goodness) + '\n'); |
| |
| # Validate parameter: run_id |
| if (run_id == ''): |
| rv = plpy.execute( 'SELECT pg_backend_pid() as pid'); |
| run_id = str( rv[0]['pid']); |
| info( ' * run_id = %s (autogenerated)' % run_id); |
| else: |
| info( ' * run_id = %s' % run_id); |
| |
| # Validate parameter: run_id |
| if (output_schema == ''): |
| output_schema = 'public' |
| info( ' * output_schema = %s' % output_schema); |
| else: |
| rv = plpy.execute( "SELECT count(*) as cnt FROM pg_namespace WHERE nspname = '%s'" % output_schema); |
| if (rv[0]['cnt'] > 0): |
| info( ' * output_schema = %s' % output_schema); |
| else: |
| plpy.error( "output_schema does not exists (" + output_schema + ")\n"); |
| |
| |
| # Global variables |
| global output_points, output_centroids |
| |
| # Set Output tables |
| output_points = quote_ident( output_schema) + '.' + quote_ident( 'kmeans_out_points_' + run_id); |
| output_centroids = quote_ident( output_schema) + '.' + quote_ident( 'kmeans_out_centroids_' + run_id); |
| |
| # Initialize centroids |
| global caller; # set variable to indicate that kmeans_init is called from kmeans_run |
| caller = 'kmeans_run'; # as above |
| c_count = __kmeans_init( input_table, k); |
| if (c_count != k): |
| info( 'Requested %d centroids, but %d have been seeded.' % k, c_count); |
| |
| # Calculate the sample size |
| sample_size = min( int( sampling_size * floor( - log( 1 - pow( 0.999, 1/float(k))) * k)), max_sample_size); |
| |
| # Prepare either the sample or the full data set |
| plpy.execute( 'DROP TABLE IF EXISTS TempTable0'); |
| sql = ''' |
| CREATE TEMP TABLE TempTable0( |
| pid BIGINT, |
| position MADLIB_SCHEMA.SVEC, |
| cid INTEGER |
| ) |
| '''; |
| plpy.execute( sql); |
| if (sampling == 1 and sample_size < p_count): |
| info( 'Using sample data set for analysis... (' + str(sample_size) + ' out of ' + str(p_count) + ' points)'); |
| sql = ''' |
| INSERT INTO TempTable0 |
| SELECT pid, position, 0 FROM ''' + input_table + ''' |
| ORDER BY random() |
| LIMIT ''' + str( sample_size); |
| expand = 1; |
| else: |
| info( 'Using full data set for analysis (' + str(p_count) + ' points)'); |
| sql = ''' |
| INSERT INTO TempTable0 |
| SELECT pid, position, 0 FROM ''' + input_table; |
| expand = 0; |
| plpy.execute( sql); |
| |
| # Main Loop |
| i = 0; |
| while (done == 0): |
| |
| # Loop index |
| i = i + 1; |
| info( '...Iteration ' + str(i)); |
| |
| # Create a temporary array of current cetroids |
| plpy.execute( 'DROP TABLE IF EXISTS ArrayOfCentroids'); |
| sql = ''' |
| CREATE TEMP TABLE ArrayOfCentroids AS |
| SELECT array( |
| SELECT position |
| FROM |
| (SELECT x FROM generate_series( 1, ''' + str(k) + ''') x) x |
| LEFT OUTER JOIN ''' + output_centroids + ''' c |
| ON x.x = c.cid |
| ORDER BY x |
| LIMIT ''' + str(k) + ''' |
| ) as arr |
| '''; |
| plpy.execute( sql); |
| |
| # Create new temp table (i) |
| plpy.execute( 'DROP TABLE IF EXISTS TempTable' + str(i)); |
| plpy.execute( 'CREATE TEMP TABLE TempTable' + str(i) + '(like TempTable' + str(i-1) + ')'); |
| |
| # For each point assign the closest centroid |
| sql = ''' |
| INSERT INTO TempTable''' + str(i) + ''' |
| SELECT |
| p.pid, |
| p.position, |
| MADLIB_SCHEMA.__kmeans_closestID( p.position, arr.arr) as cid |
| FROM |
| TempTable''' + str(i-1) + ''' p CROSS JOIN ArrayOfCentroids arr |
| '''; |
| plpy.execute( sql); |
| |
| # Refresh the Centroids table based on current assignments |
| plpy.execute( 'TRUNCATE TABLE ' + output_centroids); |
| sql = ''' |
| INSERT INTO ''' + output_centroids + ''' |
| SELECT t.cid, t.position |
| FROM |
| ( |
| SELECT |
| MADLIB_SCHEMA.__kmeans_meanPosition( position) AS position |
| , cid AS cid |
| FROM TempTable''' + str(i) + ''' |
| GROUP BY cid |
| ) AS t'''; |
| plpy.execute( sql); |
| |
| # Calculate the number of points that changed the assignment |
| sql = ''' |
| SELECT SUM( ABS( t1.e - t2.e))::float as sum |
| FROM |
| (SELECT count( t1.pid) AS e, t1.cid FROM TempTable''' + str(i) + ''' t1 GROUP BY t1.cid) AS t1 |
| JOIN |
| (SELECT count( t2.pid) AS e, t2.cid FROM TempTable''' + str(i-1) + ''' t2 GROUP BY t2.cid) AS t2 |
| ON t1.cid = t2.cid |
| '''; |
| rv = plpy.execute( sql); |
| |
| # Drop previous temp table (i-1) |
| plpy.execute( 'DROP TABLE TempTable' + str(i-1)); |
| |
| # Add it to the tracking variable |
| if (i>1): change_pct.append( rv[0]['sum'] / p_count); |
| |
| # Exit conditions: |
| if (i>3) and (change_pct[i-4] - change_pct[i-1] < 0): |
| done = 1; |
| info( 'Exit reason: fraction of reassigned nodes has stabilized around ' + str(( change_pct[i-4] + change_pct[i-3] + change_pct[i-2] + change_pct[i-1]) / 4.0)); |
| elif (change_pct[i-1] < change_pct_limit): |
| done = 1; |
| info( 'Exit reason: fraction of reassigned nodes is smaller than the limit: ' + str(change_pct_limit)); |
| elif (i==max_iterations): |
| done = 1; |
| info( 'Exit reason: reached the maximum number of allowed iterations: ' + str(max_iterations)); |
| |
| # Main Loop - END |
| |
| # Expand to all points if executed on a sample |
| if ( expand == 1): |
| info( 'Expanding cluster assignment to all points...'); |
| sql = ''' |
| INSERT INTO ''' + output_points + ''' |
| SELECT |
| p.pid, |
| p.position, |
| MADLIB_SCHEMA.__kmeans_closestID( p.position, arr.arr) as cid |
| FROM |
| ''' + input_table + ''' p CROSS JOIN ArrayOfCentroids arr |
| '''; |
| # Fill the output table |
| else: |
| info( 'Writing final output table...'); |
| sql = ''' |
| INSERT INTO ''' + output_points + ''' |
| SELECT * FROM TempTable''' + str(i); |
| |
| plpy.execute( sql); |
| |
| # Calculate Goodness of fit |
| # This version uses sum( l2norm())/count(*) |
| # It would be better to do sum( l2norm(p,c)) / max( avg(points) - point) |
| if (goodness==1): |
| info( 'Calculating goodness of fit...'); |
| sql = ''' |
| SELECT sum( MADLIB_SCHEMA.l2norm (p.position - c.position)) / count(*) as gfit |
| FROM ''' + output_points + ''' p, ''' + output_centroids + ''' c |
| WHERE p.cid = c.cid |
| '''; |
| rv = plpy.execute( sql); |
| gfit = '%s' % rv[0]['gfit']; |
| else: |
| gfit = 'not computed'; |
| |
| # Count final number of centroids |
| rv = plpy.execute( 'SELECT count(*) as count FROM ' + output_centroids); |
| c_count_final = rv[0]['count']; |
| |
| # Runtime evaluation |
| end = datetime.datetime.now(); |
| minutes, seconds = divmod( (end - start).seconds, 60) |
| microsec = (end - start).microseconds |
| |
| return ''' |
| Finished k-means clustering for %s. |
| Results: |
| * %s centroids (goodness of fit = %s). |
| Output: |
| * table : %s |
| * table : %s |
| Time elapsed: %d minutes %d.%d seconds. |
| ''' % (input_table, c_count_final, gfit, output_centroids, output_points, minutes, seconds, microsec) |