| import plpy |
| |
| from utilities.control import IterationController2S |
| from utilities.utilities import unique_string |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| |
def __lsvm_train(schema_madlib, rel_args, rel_state, rel_source,
                 col_ind_var, col_dep_var, parallel, **kwargs):
    """
    Drive the IGD iterations that train a linear support vector machine.

    Every pass runs the step aggregate over rel_source cross-joined with
    rel_args and then evaluates the stopping test. Rows whose dependent
    variable is negative are handed to the aggregate as false, all other
    rows as true. The loop ends once the iteration counter exceeds
    _args.num_iterations or the value returned by
    internal_linear_svm_igd_distance falls below _args.tolerance.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments; the columns dimension, stepsize, reg,
        num_iterations and tolerance are read from it
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param col_ind_var Name of the independent variables column
    @param col_dep_var Name of the dependent variable column
    @param parallel Using parallel aggregate if True, otherwise the
        aggregate with the _serial suffix
    @param kwargs Accepted for call compatibility and not used
    @return The iteration number (i.e., the key) with which to look up the
        result in rel_state
    """
    agg = 'linear_svm_igd_step' if parallel else 'linear_svm_igd_step_serial'

    # One IGD step. The m4 conditional embedded in the SQL decides at build
    # time how the previous state is handed to the aggregate on HAWQ versus
    # other platforms.
    step_sql = """
        SELECT
            {schema_madlib}.{agg}(
                (_src.{col_ind_var})::FLOAT8[],
                CASE WHEN (_src.{col_dep_var}) < 0 THEN false
                     ELSE true END,
                m4_ifdef(`__HAWQ__', `
                ({{__state__}})
                ', `
                (SELECT _state FROM {rel_state}
                    WHERE _iteration = {iteration})
                '),
                (_args.dimension)::INT4,
                (_args.stepsize)::FLOAT8,
                (_args.reg)::FLOAT8)
        FROM {rel_source} AS _src, {rel_args} AS _args
        """

    # Stopping test: iteration budget exhausted or state change below the
    # tolerance stored in the argument table.
    stop_sql = """
        {iteration} > _args.num_iterations OR
        {schema_madlib}.internal_linear_svm_igd_distance(
            m4_ifdef(`__HAWQ__', `
            (SELECT _state FROM {rel_state}
                WHERE _iteration = {iteration} - 1),
            (SELECT _state FROM {rel_state}
                WHERE _iteration = {iteration})',`_state_previous, _state_current')) < _args.tolerance
        """

    controller = IterationController2S(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="double precision[]",
        truncAfterIteration=False,
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        col_ind_var=col_ind_var,
        col_dep_var=col_dep_var,
        agg=agg)

    with controller as ctrl:
        ctrl.iteration = 0
        finished = False
        while not finished:
            ctrl.update(step_sql)
            finished = ctrl.test(stop_sql)

    return controller.iteration
| |
| # --------------------------------------------------- |
| # Function to run the linear classification algorithm |
| # --------------------------------------------------- |
def lsvm_classification(schema_madlib, input_table, model_table, parallel,
                        verbose, eta, reg, max_iter, **kwargs):
    """
    Executes the linear support vector classification algorithm.

    The input table must provide the columns id, ind and label. The learned
    model is written to a newly created table named model_table.

    @param schema_madlib Name of the MADlib schema
    @param input_table Name of table/view containing the training data
    @param model_table Name under which we want to store the learned model
    @param parallel Flag forwarded to the training driver, which uses it to
        choose between the parallel and the serial step aggregate
        (default value is False)
    @param verbose Verbosity of reporting (default value is False)
    @param eta Initial learning rate in (0,1] (default value is 0.1)
    @param reg Regularization parameter, often chosen by cross-validation
        (default value is 0.001)
    @param max_iter Maximum iterations to run, a positive integer
        (default value is 100)
    @param kwargs Accepted for call compatibility and not used
    @return A one-element list holding the tuple
        (model_table, '1', number of input rows, dimension, loss, 1., 0.)

    Reports an error through plpy.error when eta is outside (0, 1], when
    max_iter is not positive, or when the dimension of the ind column
    cannot be determined.
    """
    # Apply the defaults for NULL arguments BEFORE the range checks. Checking
    # first compares None against numbers, which rejects NULL under Python 2
    # (None sorts below every number) and raises TypeError under Python 3,
    # so the documented defaults could never take effect.
    parallel = False if parallel is None else parallel
    verbose = False if verbose is None else verbose
    eta = .1 if eta is None else eta
    reg = .001 if reg is None else reg
    max_iter = 100 if max_iter is None else max_iter

    # parameter validation
    input_tbl_valid(input_table, 'SVM')
    cols_in_tbl_valid(input_table, ['id', 'ind', 'label'], 'SVM')
    output_tbl_valid(model_table, 'SVM')
    if eta <= 0 or eta > 1:
        plpy.error("SVM error: eta should be in (0, 1]!")
    if max_iter <= 0:
        plpy.error("SVM error: max_iter should be greater than 0!")

    # Remember the message level so it can be restored before returning.
    old_msg_level = plpy.execute(
        "select setting from pg_settings where "
        "name='client_min_messages'")[0]['setting']
    if verbose:
        plpy.info("Parameters:")
        plpy.info(" * input_table = %s" % input_table)
        plpy.info(" * model_table = " + model_table)
        plpy.info(" * parallel = " + str(parallel))
        plpy.info(" * eta = " + str(eta))
        plpy.info(" * reg = " + str(reg))
        plpy.info(" * max_iter = " + str(max_iter))
        plpy.execute("set client_min_messages to warning")
    else:
        plpy.execute("set client_min_messages to error")

    # arguments for iterating
    dim_rows = plpy.execute(
        "SELECT array_upper(ind, 1) AS dim FROM " + input_table + " LIMIT 1")
    ind_dim = dim_rows[0]['dim'] if len(dim_rows) > 0 else None
    if ind_dim is None:
        # Without this guard the text None would be formatted into the SQL
        # below and fail with an unrelated-looking error.
        plpy.error("SVM error: cannot determine the dimension of column "
                   "ind in " + input_table + "!")
    inds = plpy.execute("SELECT count(*) AS c FROM " + input_table)[0]['c']

    rel_args = unique_string()
    rel_state = unique_string()

    # Exactly the arguments the training driver declares; built explicitly
    # instead of copying locals() so nothing unintended is forwarded.
    train_args = {
        'schema_madlib': schema_madlib,
        'rel_args': rel_args,
        'rel_state': rel_state,
        'rel_source': input_table,
        'col_ind_var': 'ind',    # hard-coded name
        'col_dep_var': 'label',  # hard-coded name
        'parallel': parallel,
    }

    # temp table for iterating
    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp()".format(
        schema_madlib=schema_madlib))
    plpy.execute("""
        DROP TABLE IF EXISTS pg_temp.{rel_args};
        CREATE TABLE pg_temp.{rel_args} AS
        SELECT
            {ind_dim} AS dimension,
            {eta} AS stepsize,
            {reg} AS reg,
            {max_iter} AS num_iterations,
            0 AS tolerance;
        """.format(rel_args=rel_args, ind_dim=ind_dim, eta=eta, reg=reg,
                   max_iter=max_iter))

    # actual iterative algorithm computation
    iteration_run = __lsvm_train(**train_args)

    # organizing results
    plpy.execute("""
        CREATE TABLE {model_table} AS
            SELECT
                1 AS id, -- for compatibility
                (result).coefficients AS weights,
                1. AS wdiv, -- for compatibility
                0. AS wbias, -- for compatibility
                (result).loss AS loss,
                {iteration_run} AS iter_run
            FROM (
                SELECT {schema_madlib}.internal_linear_svm_igd_result(
                    _state) AS result
                FROM {rel_state}
                WHERE _iteration = {iteration_run}) subq
        """.format(model_table=model_table, iteration_run=iteration_run,
                   schema_madlib=schema_madlib, rel_state=rel_state))

    # summarizing
    loss = plpy.execute("SELECT loss FROM " + model_table)[0]['loss']
    summary = [(
        model_table,
        '1',  # for compatibility
        inds,
        ind_dim,
        loss,
        1.,
        0.
    )]

    plpy.execute("set client_min_messages to " + old_msg_level)
    return summary
| |
| |
| # ------------------------------------------------------------------------------ |
| # Function to predict the labels of points in a table using a linear classifier |
| # ------------------------------------------------------------------------------ |
def lsvm_predict_batch(schema_madlib, input_table, data_col, id_col,
                       model_table, output_table, parallel, **kwargs):
    """
    Scores the data points stored in a table using a learned support vector
    model and stores one prediction per input row in a new table.

    @param schema_madlib Name of the MADlib schema
    @param input_table Name of table/view containing the data points to be
        scored
    @param data_col Name of column in input_table containing the data points
    @param id_col Name of column in input_table containing (integer)
        identifier for data point
    @param model_table Name of learned model; its weights column is read
    @param output_table Name of table to store the results (columns id and
        prediction)
    @param parallel Deprecated and ignored boolean flag. (Default: None)
    @param kwargs Accepted for call compatibility and not used
    @return A short text message naming the input and output tables
    """
    module = 'SVM'

    # parameter validation
    input_tbl_valid(input_table, module)
    cols_in_tbl_valid(input_table, [data_col, id_col], module)
    input_tbl_valid(model_table, module)
    cols_in_tbl_valid(model_table, ['weights'], module)
    output_tbl_valid(output_table, module)

    # use lsvm_predict internally
    scoring_sql = """
        CREATE TABLE {output_table} AS
        SELECT
            {id_col} AS id,
            {schema_madlib}.lsvm_predict(
                (SELECT weights FROM {model_table}),
                {data_col}) AS prediction
        FROM {input_table}
        """.format(
        output_table=output_table,
        id_col=id_col,
        schema_madlib=schema_madlib,
        model_table=model_table,
        data_col=data_col,
        input_table=input_table)
    plpy.execute(scoring_sql)

    report = """
Finished processing data points in %s table.
Results are stored in %s table.
"""
    return report % (input_table, output_table)