| # coding=utf-8 |
| |
| """ |
| @file crf.py_in |
| |
| @brief Conditional Random Field: Driver functions |
| |
| @namespace crf |
| |
| Conditional Random Field: Driver functions |
| """ |
| |
| import plpy |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.utilities import _assert |
| |
def __runIterativeAlg(stateType, initialState, source, updateExpr,
                      terminateExpr, maxNumIterations, cyclesPerIteration = 1):
    """
    Generic driver for an iterative, state-passing algorithm

    The state between iterations is stored in the temporary table
    <tt>_madlib_iterative_alg</tt>, one row per iteration, in a column of type
    <tt>stateType</tt>. Row 0 holds <tt><em>initialState</em></tt>. Each pass
    of the loop first executes the update statement (built from
    <tt>updateExpr</tt>), which inserts the next state, and then decides
    whether to stop:
      - always, if the freshly inserted state is NULL;
      - otherwise, only once <tt>iteration > cyclesPerIteration</tt>, if either
        <tt>iteration >= cyclesPerIteration * maxNumIterations</tt> or the
        termination query (built from <tt>terminateExpr</tt>) returns true.

    @param stateType SQL type of the state between iterations
    @param initialState SQL expression stored as the state of iteration 0
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        <tt>stateType</tt>. The expression may use the replacement fields
        <tt>"{state}"</tt> (state of the previous iteration),
        <tt>"{iteration}"</tt>, and <tt>"{sourceAlias}"</tt>, the alias under
        which <tt><em>source</em></tt> is joined.
    @param terminateExpr SQL expression that returns whether the algorithm
        should terminate. The expression may use the replacement fields
        <tt>"{oldState}"</tt> (state from <tt>cyclesPerIteration</tt>
        iterations ago), <tt>"{newState}"</tt>, <tt>"{iteration}"</tt>, and
        <tt>"{cyclesPerIteration}"</tt>. It must return a BOOLEAN value.
    @param maxNumIterations Maximum number of iterations. The loop stops at
        this cap even when <tt>terminateExpr</tt> does not evaluate to true.
    @param cyclesPerIteration Number of aggregate function calls per iteration.
    @return Index of the last iteration that was executed. The temporary
        table is intentionally left in place so that callers can read the
        final state from it.
    """

    # Two-stage templates: the first format() splices in the caller-supplied
    # expression and turns the doubled braces into single-brace fields, which
    # are filled in on every pass of the loop below.
    update_sql = """
        INSERT INTO _madlib_iterative_alg
        SELECT
            {{iteration}},
            {updateExpr}
        FROM
            _madlib_iterative_alg AS st,
            {{source}} AS src
        WHERE
            st._madlib_iteration = {{iteration}} - 1
        """.format(updateExpr=updateExpr)
    terminate_sql = """
        SELECT
            {terminateExpr} AS should_terminate
        FROM
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
        ) AS older,
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}}
        ) AS newer
        """.format(terminateExpr=terminateExpr)
    # Single-stage template: formatted once per pass, hence single braces.
    null_state_sql = """
        SELECT _madlib_state IS NULL AS should_terminate
        FROM _madlib_iterative_alg
        WHERE _madlib_iteration = {iteration}
        """

    # Remember the current message level so it can be restored after the
    # table is (re)created with notices suppressed.
    level_rows = plpy.execute(
        "SELECT setting FROM pg_settings WHERE name='client_min_messages'")
    previous_level = level_rows[0]['setting']
    plpy.execute("""
        SET client_min_messages = error;
        DROP TABLE IF EXISTS _madlib_iterative_alg;
        CREATE TEMPORARY TABLE _madlib_iterative_alg (
            _madlib_iteration INTEGER
                m4_ifdef(`__HAWQ__', `', ` PRIMARY KEY'),
            _madlib_state {stateType}
        )
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (_madlib_iteration)');
        SET client_min_messages = {oldMsgLevel};
        """.format(stateType=stateType, oldMsgLevel=previous_level))

    # Seed the table with the initial state as iteration 0.
    iteration = 0
    plpy.execute("""
        INSERT INTO _madlib_iterative_alg VALUES ({iteration}, {initialState})
        """.format(iteration=iteration, initialState=initialState))

    while True:
        iteration += 1
        plpy.execute(update_sql.format(
            source=source,
            state="(st._madlib_state)",
            iteration=iteration,
            sourceAlias="src"))

        # A NULL state ends the run unconditionally.
        null_rows = plpy.execute(null_state_sql.format(iteration=iteration))
        if null_rows[0]['should_terminate']:
            break

        # Until more than cyclesPerIteration passes have run there is no
        # older state to compare against, so neither the cap nor the
        # termination query is consulted.
        # NOTE(review): as a consequence the cap cannot stop the loop before
        # iteration cyclesPerIteration + 1 -- confirm this is intended for
        # maxNumIterations = 1.
        if iteration <= cyclesPerIteration:
            continue

        if iteration >= cyclesPerIteration * maxNumIterations:
            break

        verdict_rows = plpy.execute(terminate_sql.format(
            iteration=iteration,
            cyclesPerIteration=cyclesPerIteration,
            oldState="(older._madlib_state)",
            newState="(newer._madlib_state)"))
        if verdict_rows[0]['should_terminate']:
            break

    # Note: We do not drop the temporary table
    return iteration
| |
| |
def compute_lincrf(schema_madlib, source, sparse_R, dense_M, sparse_M, featureSize, tagSize, maxNumIterations, **kwargs):
    """
    Run the L-BFGS iterations for a linear-chain conditional random field

    This method serves as an interface to the L-BFGS optimization step: it
    builds the per-iteration update expression and the convergence test and
    hands both to the generic iterative driver.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param sparse_R Name of the sparse single state feature column
        (cast to FLOAT8[] in the generated SQL)
    @param dense_M Name of the dense two state feature column
        (cast to FLOAT8[] in the generated SQL)
    @param sparse_M Name of the sparse two state feature column
        (cast to FLOAT8[] in the generated SQL)
    @param featureSize Name of feature size column in training data
        (cast to FLOAT8 in the generated SQL)
    @param tagSize The size of the tag set (cast to FLOAT8 in the generated SQL)
    @param maxNumIterations Maximum number of iterations; must be at least 1
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration index returned by the iterative driver, i.e. the
        row of the driver's temporary state table that holds the final state.
        The coefficients themselves are not returned here.
    """

    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")

    # The doubled braces survive this format() call as a single-brace
    # field, which the iterative driver fills in later.
    step_expr = """
        {schema_madlib}.lincrf_lbfgs_step(
            ({sparse_R})::FLOAT8[],
            ({dense_M})::FLOAT8[],
            ({sparse_M})::FLOAT8[],
            ({featureSize})::FLOAT8,
            ({tagSize})::FLOAT8,
            {{state}}
        )
        """.format(
        schema_madlib=schema_madlib,
        sparse_R=sparse_R,
        dense_M=dense_M,
        sparse_M=sparse_M,
        featureSize=featureSize,
        tagSize=tagSize)

    # The run stops once the convergence helper reports 0 for the new state.
    converged_expr = """
        {schema_madlib}.internal_lincrf_lbfgs_converge(
            {{newState}}) = 0
        """.format(
        schema_madlib=schema_madlib)

    return __runIterativeAlg(
        stateType="FLOAT8[]",
        initialState="NULL",
        source=source,
        updateExpr=step_expr,
        terminateExpr=converged_expr,
        maxNumIterations=maxNumIterations)
| |
def lincrf_train(schema_madlib, train_feature_tbl, train_featureset_tbl,
                 label_tbl, crf_stats_tbl, crf_weights_tbl, max_iterations, **kwargs):
    """
    @brief Train a linear-chain CRF and store the results in two new tables

    Validates the arguments, runs <tt>compute_lincrf</tt> through SQL on the
    columns <tt>sparse_r</tt>, <tt>dense_m</tt>, <tt>sparse_m</tt> and
    <tt>f_size</tt> of the feature table, then materializes
      - the stats table (coef, log_likelihood, num_iterations), extracted
        from the final state left in <tt>_madlib_iterative_alg</tt>, and
      - the weights table (id, name, prev_label_id, label_id, weight), built
        by combining the stats table with the featureset table.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param train_feature_tbl Name of the training feature table
    @param train_featureset_tbl Name of the featureset table; its columns
        f_index, f_name and feature are read
    @param label_tbl Name of the label table; its row count is used as the
        tag set size
    @param crf_stats_tbl Name of the stats table to create
    @param crf_weights_tbl Name of the weights table to create
    @param max_iterations Maximum number of iterations
    @param kwargs Additional arguments; ignored
    @return A fixed success message
    """

    tag_size = _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                              crf_stats_tbl, crf_weights_tbl, max_iterations)

    # Run the optimizer; the value returned is the iteration index under
    # which the final state is stored.
    train_sql = """
        SELECT {schema_madlib}.compute_lincrf('{train_feature_tbl}',
            'sparse_r', 'dense_m', 'sparse_m', 'f_size',
            {tag_size}, {max_iterations}) iterations
        """.format(schema_madlib=schema_madlib,
                   train_feature_tbl=train_feature_tbl,
                   tag_size=tag_size,
                   max_iterations=max_iterations)
    iterations = plpy.execute(train_sql)[0]['iterations']

    # Unpack the final state into the stats table.
    stats_sql = """CREATE TABLE {crf_stats_tbl} AS
        SELECT (result).coef, (result).log_likelihood, (result).num_iterations
        FROM
            (SELECT {schema_madlib}.internal_lincrf_lbfgs_result(_madlib_state) AS result
             FROM _madlib_iterative_alg
             WHERE _madlib_iteration = {iteration}
            ) Q
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (num_iterations)')
        """.format(crf_stats_tbl=crf_stats_tbl,
                   iteration=iterations,
                   schema_madlib=schema_madlib)
    plpy.execute(stats_sql)

    # One weight per featureset row.
    # presumably f_index is zero-based, hence the + 1 when indexing the
    # coefficient array -- verify against the featureset generator.
    weights_sql = """CREATE TABLE {crf_weights_tbl} AS
        SELECT f_index id, f_name as name,
               feature[1] prev_label_id,
               feature[2] label_id,
               coef[f_index + 1] weight
        FROM
            {crf_stats_tbl},
            {train_featureset_tbl}
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')
        """.format(crf_weights_tbl=crf_weights_tbl,
                   crf_stats_tbl=crf_stats_tbl,
                   train_featureset_tbl=train_featureset_tbl)
    plpy.execute(weights_sql)

    # Enforce ANALYZE to gather proper table statistics required to generate optimized query plans
    analyze_sql = """ ANALYZE {crf_weights_tbl} """.format(crf_weights_tbl=crf_weights_tbl)
    plpy.execute(analyze_sql)

    return "CRF Train successful. Results stored in the specified CRF stats and weights table"
| |
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Assert that every column in cols exists in the given table

    @param cols List of required column names
    @param table_name Name of the table to inspect
    @param err_msg_tbl Human-readable table label used in the error message
    """

    all_present = columns_exist_in_table(table_name, cols)
    required = ', '.join(cols)
    message = "CRF error: Missing required columns from %s table: %s" % (err_msg_tbl, required)
    _assert(all_present, message)
| |
def _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                   crf_stats_tbl, crf_weights_tbl, max_iterations):
    """
    @brief Validate the arguments of CRF training and compute the tag set size

    Checks, in this order: the three input tables exist; max_iterations is
    positive; the input tables have their required columns; the label table
    is non-empty; both output table names are usable; neither output table
    exists yet. The first failing check raises through _assert.

    @param train_feature_tbl Name of the training feature table
    @param train_featureset_tbl Name of the featureset table
    @param label_tbl Name of the label table
    @param crf_stats_tbl Name of the stats table to be created
    @param crf_weights_tbl Name of the weights table to be created
    @param max_iterations Maximum number of iterations
    @return Number of rows in the label table
    """

    # Input relations must exist.
    input_tables = (
        (train_feature_tbl, "CRF error: Train feature table does not exist!"),
        (train_featureset_tbl, "CRF error: Train featureset table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for tbl, message in input_tables:
        _assert(table_exists(tbl), message)

    _assert(max_iterations > 0,
            "CRF error: max iterations cannot be zero or negative")

    # Validate required columns
    column_requirements = (
        (['doc_id', 'f_size', 'sparse_r', 'dense_m', 'sparse_m'],
         train_feature_tbl, "feature"),
        (['f_index', 'f_name', 'feature'], train_featureset_tbl, "featureset"),
        (['id', 'label'], label_tbl, "label"),
    )
    for cols, tbl, tbl_label in column_requirements:
        _validate_columns(cols, tbl, tbl_label)

    # The tag set size is the number of rows in the label table.
    count_sql = """ SELECT count(*) FROM {label_tbl}
        """.format(label_tbl=label_tbl)
    tag_size = plpy.execute(count_sql)[0]['count']
    _assert(tag_size > 0,
            "CRF error: Label table is empty")

    # Output names must be real names (not NULL / empty) ...
    output_names = (
        (crf_stats_tbl, "CRF error: Invalid CRF stats table name"),
        (crf_weights_tbl, "CRF error: Invalid CRF weights table name"),
    )
    for tbl, message in output_names:
        _assert(tbl is not None and
                tbl.lower().strip() not in ('null', ''),
                message)

    # ... and must not collide with an existing table.
    output_collisions = (
        (crf_stats_tbl,
         "CRF error: CRF stats table already exist!"
         " Please provide a different table name."),
        (crf_weights_tbl,
         "CRF error: CRF weights table already exist!"
         " Please provide a different table name."),
    )
    for tbl, message in output_collisions:
        _assert(not table_exists(tbl, only_first_schema=True), message)

    return tag_size