blob: 710ad340a6a13bb56e0d87ffbe9b5b5380962ba8 [file] [log] [blame]
import plpy
from utilities.control import IterationController2S
from utilities.utilities import unique_string
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
def __lsvm_train(schema_madlib, rel_args, rel_state, rel_source,
                 col_ind_var, col_dep_var, parallel, **kwargs):
    """
    Run the IGD iteration loop that trains a linear support vector machine.

    Each pass executes one aggregate step over the source relation and then
    evaluates the stopping condition; the loop ends as soon as that
    condition holds.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param col_ind_var Name of the independent variables column
    @param col_dep_var Name of the dependent variable column
    @param parallel Use the parallel aggregate if true, otherwise the
        aggregate with the _serial suffix
    @param kwargs Extra keyword arguments are accepted and ignored
    @return The iteration number (i.e., the key) with which to look up the
        result in rel_state
    """
    # The serial aggregate has the same base name plus a suffix.
    suffix = '' if parallel else '_serial'
    step_aggregate = 'linear_svm_igd_step' + suffix

    # One aggregate step over the source relation. The m4 conditional picks
    # how the previous state is handed to the aggregate on each platform.
    # The template text is kept flush left so the SQL is unchanged.
    update_sql = """
SELECT
{schema_madlib}.{agg}(
(_src.{col_ind_var})::FLOAT8[],
CASE WHEN (_src.{col_dep_var}) < 0 THEN false
ELSE true END,
m4_ifdef(`__HAWQ__', `
({{__state__}})
', `
(SELECT _state FROM {rel_state}
WHERE _iteration = {iteration})
'),
(_args.dimension)::INT4,
(_args.stepsize)::FLOAT8,
(_args.reg)::FLOAT8)
FROM {rel_source} AS _src, {rel_args} AS _args
"""

    # Stop when the iteration budget is exceeded or when the distance
    # between the previous and the current state drops below the tolerance.
    stop_sql = """
{iteration} > _args.num_iterations OR
{schema_madlib}.internal_linear_svm_igd_distance(
m4_ifdef(`__HAWQ__', `
(SELECT _state FROM {rel_state}
WHERE _iteration = {iteration} - 1),
(SELECT _state FROM {rel_state}
WHERE _iteration = {iteration})',`_state_previous, _state_current')) < _args.tolerance
"""

    controller = IterationController2S(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="double precision[]",
        truncAfterIteration=False,
        # The remaining keywords are identifiers used by the SQL templates.
        schema_madlib=schema_madlib,
        rel_source=rel_source,
        col_ind_var=col_ind_var,
        col_dep_var=col_dep_var,
        agg=step_aggregate)

    with controller as ctrl:
        ctrl.iteration = 0
        finished = False
        while not finished:
            ctrl.update(update_sql)
            finished = ctrl.test(stop_sql)

    return controller.iteration
# ---------------------------------------------------
# Function to run the linear classification algorithm
# ---------------------------------------------------
def lsvm_classification(schema_madlib, input_table, model_table, parallel,
                        verbose, eta, reg, max_iter, **kwargs):
    """
    Executes the linear support vector classification algorithm.

    NULL (None) values of parallel, verbose, eta, reg and max_iter are
    replaced by their defaults before eta and max_iter are range-checked,
    so passing NULL selects the default instead of failing validation.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param input_table Name of table/view containing the training data; the
        columns id, ind and label are required
    @param model_table Name under which we want to store the learned model
    @param parallel A flag indicating whether the system should learn
        multiple models in parallel (default value is False)
    @param verbose Verbosity of reporting (default value is False)
    @param eta Initial learning rate in (0,1] (default value is 0.1)
    @param reg Regularization parameter, often chosen by cross-validation
        (default value is 0.001)
    @param max_iter Maximum iterations to run, a positive integer
        (default value is 100)
    @param kwargs Extra keyword arguments are accepted and ignored
    @return A one-element list with the tuple (model_table, '1', number of
        input rows, dimension of ind in the first fetched row, loss, 1., 0.)
    """
    # table and column validation
    input_tbl_valid(input_table, 'SVM')
    cols_in_tbl_valid(input_table, ['id', 'ind', 'label'], 'SVM')
    output_tbl_valid(model_table, 'SVM')

    # Apply defaults for NULL arguments BEFORE the range checks. Comparing
    # None against a number is true in Python 2 and a TypeError in Python 3,
    # so validating first made the defaults unreachable.
    parallel = False if parallel is None else parallel
    verbose = False if verbose is None else verbose
    eta = .1 if eta is None else eta
    reg = .001 if reg is None else reg
    max_iter = 100 if max_iter is None else max_iter

    # range validation of the numeric arguments
    if eta <= 0 or eta > 1:
        plpy.error("SVM error: eta should be in (0, 1]!")
    if max_iter <= 0:
        plpy.error("SVM error: max_iter should be greater than 0!")

    # Remember the current message level so it can be restored at the end,
    # then adjust it according to the verbose flag.
    old_msg_level = plpy.execute(
        "select setting from pg_settings where "
        "name='client_min_messages'")[0]['setting']
    if verbose:
        plpy.info("Parameters:")
        plpy.info(" * input_table = %s" % input_table)
        plpy.info(" * model_table = " + model_table)
        plpy.info(" * parallel = " + str(parallel))
        plpy.info(" * eta = " + str(eta))
        plpy.info(" * reg = " + str(reg))
        plpy.execute("set client_min_messages to warning")
    else:
        plpy.execute("set client_min_messages to error")

    # arguments for iterating: the dimension is read from a single row and
    # the row count is only used for the returned summary
    ind_dim = plpy.execute("SELECT array_upper(ind, 1) AS dim FROM " +
                           input_table + " LIMIT 1")[0]['dim']
    inds = plpy.execute("SELECT count(*) AS c FROM " + input_table)[0]['c']

    # Explicit mapping of every name used by the SQL templates below and by
    # __lsvm_train, instead of dumping all locals into the dictionary.
    args = {'schema_madlib': schema_madlib,
            'rel_args': unique_string(),
            'rel_state': unique_string(),
            'rel_source': input_table,
            'col_ind_var': 'ind',      # hard-coded name
            'col_dep_var': 'label',    # hard-coded name
            'parallel': parallel,
            'model_table': model_table,
            'ind_dim': ind_dim,
            'eta': eta,
            'reg': reg,
            'max_iter': max_iter}

    # temp table holding the non-template arguments of the iteration
    plpy.execute(
        "SELECT {schema_madlib}.create_schema_pg_temp()".format(**args))
    plpy.execute("""
DROP TABLE IF EXISTS pg_temp.{rel_args};
CREATE TABLE pg_temp.{rel_args} AS
SELECT
{ind_dim} AS dimension,
{eta} AS stepsize,
{reg} AS reg,
{max_iter} AS num_iterations,
0 AS tolerance;
""".format(**args))

    # actual iterative algorithm computation
    iteration_run = __lsvm_train(**args)

    # organizing results: unpack the final state into the model table
    plpy.execute("""
CREATE TABLE {model_table} AS
SELECT
1 AS id, -- for compatibility
(result).coefficients AS weights,
1. AS wdiv, -- for compatibility
0. AS wbias, -- for compatibility
(result).loss AS loss,
{iteration_run} AS iter_run
FROM (
SELECT {schema_madlib}.internal_linear_svm_igd_result(
_state) AS result
FROM {rel_state}
WHERE _iteration = {iteration_run}) subq
""".format(iteration_run=iteration_run, **args))

    # summarizing
    loss = plpy.execute("SELECT loss FROM " + model_table)[0]['loss']
    summary = [(
        model_table,
        '1',  # for compatibility
        inds,
        ind_dim,
        loss,
        1.,
        0.
    )]

    # restore the message level that was active on entry
    plpy.execute("set client_min_messages to " + old_msg_level)
    return summary
# ------------------------------------------------------------------------------
# Function to predict the labels of points in a table using a linear classifier
# ------------------------------------------------------------------------------
def lsvm_predict_batch(schema_madlib, input_table, data_col, id_col,
                       model_table, output_table, parallel, **kwargs):
    """
    Score every row of a table with a learned linear support vector model.

    The predictions are written to a newly created output table with the
    columns id and prediction.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param input_table Name of table/view containing the data points to be
        scored
    @param data_col Name of column in input_table containing the data points
    @param id_col Name of column in input_table containing (integer)
        identifier for data point
    @param model_table Name of learned model; its weights column is used
    @param output_table Name of table to store the results
    @param parallel Deprecated and ignored boolean flag. (Default: None)
    @param kwargs Extra keyword arguments are accepted and ignored
    @return A short message naming the input and the output table
    """
    # Both the data table and the model table are inputs; each one is
    # checked together with the columns this function reads from it.
    required_inputs = (
        (input_table, [data_col, id_col]),
        (model_table, ['weights']),
    )
    for table_name, needed_columns in required_inputs:
        input_tbl_valid(table_name, 'SVM')
        cols_in_tbl_valid(table_name, needed_columns, 'SVM')
    output_tbl_valid(output_table, 'SVM')

    # Delegate the per-row scoring to the SQL function lsvm_predict.
    # The template text is kept flush left so the SQL is unchanged.
    scoring_sql = """
CREATE TABLE {output_table} AS
SELECT
{id_col} AS id,
{schema_madlib}.lsvm_predict(
(SELECT weights FROM {model_table}),
{data_col}) AS prediction
FROM {input_table}
""".format(output_table=output_table,
           id_col=id_col,
           schema_madlib=schema_madlib,
           model_table=model_table,
           data_col=data_col,
           input_table=input_table)
    plpy.execute(scoring_sql)

    message_template = ("\n"
                        "Finished processing data points in %s table.\n"
                        "Results are stored in %s table.\n")
    return message_template % (input_table, output_table)