blob: 2eaa1e83c8455a6cdc231a1343f0fc6f4533db47 [file] [log] [blame]
# coding=utf-8
"""
@file crf.py_in
@brief Conditional Random Field: Driver functions
@namespace crf
Conditional Random Field: Driver functions
"""
import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _assert
def __runIterativeAlg(stateType, initialState, source, updateExpr,
                      terminateExpr, maxNumIterations, cyclesPerIteration = 1):
    """
    Driver for an iterative algorithm

    The state of every iteration is stored as one row of the temporary table
    <tt>_madlib_iterative_alg</tt> (columns <tt>_madlib_iteration</tt> and
    <tt>_madlib_state</tt>, the latter of type <tt>stateType</tt>). Row 0 holds
    <tt><em>initialState</em></tt>. In each iteration the update statement
    inserts a new row computed by <tt>updateExpr</tt> from the previous row
    and the source relation; afterwards the terminate query decides whether
    the loop stops.

    The loop stops as soon as one of the following holds:
    - the state that was just inserted is NULL;
    - more than <tt>cyclesPerIteration</tt> updates have run and the
      iteration number has reached
      <tt>cyclesPerIteration * maxNumIterations</tt>;
    - more than <tt>cyclesPerIteration</tt> updates have run and
      <tt>terminateExpr</tt> evaluates to true.

    @param stateType SQL type of the state between iterations
    @param initialState SQL expression for the initial value of the state
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        <tt>stateType</tt>. The expression may use the replacement fields
        <tt>"{state}"</tt>, <tt>"{iteration}"</tt>, <tt>"{source}"</tt> and
        <tt>"{sourceAlias}"</tt>. Source alias is an alias for the source
        relation <tt><em>source</em></tt>.
    @param terminateExpr SQL expression that returns whether the algorithm
        should terminate. The expression may use the replacement fields
        <tt>"{oldState}"</tt>, <tt>"{newState}"</tt>, <tt>"{iteration}"</tt>
        and <tt>"{cyclesPerIteration}"</tt>. It must return a BOOLEAN value.
    @param maxNumIterations Maximum number of iterations. Algorithm will then
        terminate even when <tt>terminateExpr</tt> does not evaluate to true
    @param cyclesPerIteration Number of aggregate function calls per iteration.
    @return Number of the last iteration that was executed. The temporary
        table is left in place so that callers can read the final state.
    """
    # The templates are filled in two stages: the caller-supplied expression
    # is spliced in here, while the doubled braces survive as single-brace
    # fields that are filled in on every pass through the loop below.
    update_template = """
INSERT INTO _madlib_iterative_alg
SELECT
{{iteration}},
{updateExpr}
FROM
_madlib_iterative_alg AS st,
{{source}} AS src
WHERE
st._madlib_iteration = {{iteration}} - 1
""".format(updateExpr = updateExpr)

    terminate_template = """
SELECT
{terminateExpr} AS should_terminate
FROM
(
SELECT _madlib_state
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
) AS older,
(
SELECT _madlib_state
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {{iteration}}
) AS newer
""".format(terminateExpr = terminateExpr)

    null_state_template = """
SELECT _madlib_state IS NULL AS should_terminate
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {iteration}
"""

    # Remember the session's message level so it can be restored after the
    # table (re-)creation, which runs with client_min_messages = error.
    level_rows = plpy.execute(
        "SELECT setting FROM pg_settings WHERE name='client_min_messages'")
    oldMsgLevel = level_rows[0]['setting']

    create_state_table_sql = """
SET client_min_messages = error;
DROP TABLE IF EXISTS _madlib_iterative_alg;
CREATE TEMPORARY TABLE _madlib_iterative_alg (
_madlib_iteration INTEGER PRIMARY KEY,
_madlib_state {stateType}
)
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (_madlib_iteration)');
SET client_min_messages = {oldMsgLevel};
""".format(stateType = stateType, oldMsgLevel = oldMsgLevel)
    plpy.execute(create_state_table_sql)

    # Seed the table with the initial state as iteration 0.
    iteration = 0
    seed_sql = """
INSERT INTO _madlib_iterative_alg VALUES ({iteration}, {initialState})
""".format(iteration = iteration, initialState = initialState)
    plpy.execute(seed_sql)

    while True:
        iteration += 1

        # Compute the state for this iteration from the previous one.
        plpy.execute(update_template.format(
            source = source,
            state = "(st._madlib_state)",
            iteration = iteration,
            sourceAlias = "src"))

        # A NULL state always ends the loop, regardless of iteration count.
        null_rows = plpy.execute(
            null_state_template.format(iteration = iteration))
        if null_rows[0]['should_terminate']:
            break

        # Neither the iteration cap nor the convergence test is evaluated
        # until more than cyclesPerIteration updates have run, so at least
        # cyclesPerIteration + 1 updates happen unless the state became NULL.
        if iteration <= cyclesPerIteration:
            continue

        if iteration >= cyclesPerIteration * maxNumIterations:
            break

        converged_rows = plpy.execute(terminate_template.format(
            iteration = iteration,
            cyclesPerIteration = cyclesPerIteration,
            oldState = "(older._madlib_state)",
            newState = "(newer._madlib_state)"))
        if converged_rows[0]['should_terminate']:
            break

    # Note: We do not drop the temporary table
    return iteration
def compute_lincrf(schema_madlib, source, sparse_R, dense_M, sparse_M, featureSize, tagSize, maxNumIterations, **kwargs):
    """
    Compute conditional random field coefficients

    This method serves as an interface to the L-BFGS optimization step
    <tt>lincrf_lbfgs_step</tt>: it builds the per-iteration update expression
    and the convergence expression and hands both to the iterative driver.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param sparse_R Name of the sparse single state feature column (cast to
        FLOAT8[] in the generated SQL)
    @param dense_M Name of the dense two state feature column (cast to
        FLOAT8[] in the generated SQL)
    @param sparse_M Name of the sparse two state feature column (cast to
        FLOAT8[] in the generated SQL)
    @param featureSize Name of feature size column in training data (cast to
        FLOAT8 in the generated SQL)
    @param tagSize The size of the tag set (cast to FLOAT8 in the generated
        SQL)
    @param maxNumIterations Maximum number of iterations
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The value returned by the iterative driver, i.e. the number of
        the last iteration executed. The resulting state itself stays in the
        driver's temporary table.
    """
    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")

    # "{{state}}" is deliberately left as the field "{state}" after this
    # format() call; the driver substitutes the previous state for it.
    lbfgs_step_expr = """
{schema_madlib}.lincrf_lbfgs_step(
({sparse_R})::FLOAT8[],
({dense_M})::FLOAT8[],
({sparse_M})::FLOAT8[],
({featureSize})::FLOAT8,
({tagSize})::FLOAT8,
{{state}}
)
""".format(
        schema_madlib = schema_madlib,
        sparse_R = sparse_R,
        dense_M = dense_M,
        sparse_M = sparse_M,
        featureSize = featureSize,
        tagSize = tagSize)

    # Likewise "{{newState}}" becomes "{newState}" for the driver to fill in.
    convergence_expr = """
{schema_madlib}.internal_lincrf_lbfgs_converge(
{{newState}}) = 0
""".format(
        schema_madlib = schema_madlib)

    return __runIterativeAlg(
        stateType = "FLOAT8[]",
        initialState = "NULL",
        source = source,
        updateExpr = lbfgs_step_expr,
        terminateExpr = convergence_expr,
        maxNumIterations = maxNumIterations)
def lincrf_train(schema_madlib, train_feature_tbl, train_featureset_tbl,
                 label_tbl, crf_stats_tbl, crf_weights_tbl, max_iterations, **kwargs):
    """
    @brief Train a linear-chain CRF and store the results in two new tables

    Validates the arguments, runs <tt>compute_lincrf</tt> through SQL on the
    feature table, then materializes the final optimizer state into
    <tt>crf_stats_tbl</tt> and the per-feature weights into
    <tt>crf_weights_tbl</tt>.

    @param schema_madlib Name of the MADlib schema
    @param train_feature_tbl Name of the training feature table; its columns
        sparse_r, dense_m, sparse_m and f_size are passed to compute_lincrf
    @param train_featureset_tbl Name of the feature set table; its columns
        f_index, f_name and feature are read for the weights table
    @param label_tbl Name of the label table; its row count is used as the
        tag size
    @param crf_stats_tbl Name of the stats table to create (coef,
        log_likelihood, num_iterations)
    @param crf_weights_tbl Name of the weights table to create (id, name,
        prev_label_id, label_id, weight)
    @param max_iterations Maximum number of iterations
    @param kwargs Additional arguments, which are ignored
    @return A fixed success message string
    """
    tag_size = _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                              crf_stats_tbl, crf_weights_tbl, max_iterations)

    # Run the optimizer; it reports the number of the last iteration, which
    # identifies the row holding the final state in _madlib_iterative_alg.
    train_sql = """
SELECT {schema_madlib}.compute_lincrf('{train_feature_tbl}',
'sparse_r', 'dense_m', 'sparse_m', 'f_size',
{tag_size}, {max_iterations}) iterations
""".format(schema_madlib = schema_madlib,
           train_feature_tbl = train_feature_tbl,
           tag_size = tag_size,
           max_iterations = max_iterations)
    iterations = plpy.execute(train_sql)[0]['iterations']

    # Unpack the final state into the stats table.
    stats_sql = """CREATE TABLE {crf_stats_tbl} AS
SELECT (result).coef, (result).log_likelihood, (result).num_iterations
FROM
(SELECT {schema_madlib}.internal_lincrf_lbfgs_result(_madlib_state) AS result
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {iteration}
) Q
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (num_iterations)')
""".format(crf_stats_tbl = crf_stats_tbl,
           iteration = iterations,
           schema_madlib = schema_madlib)
    plpy.execute(stats_sql)

    # Pair every feature with its coefficient; coef is indexed by
    # f_index + 1 in the generated SQL.
    weights_sql = """CREATE TABLE {crf_weights_tbl} AS
SELECT f_index id, f_name as name,
feature[1] prev_label_id,
feature[2] label_id,
coef[f_index + 1] weight
FROM
{crf_stats_tbl},
{train_featureset_tbl}
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')
""".format(crf_weights_tbl = crf_weights_tbl,
           crf_stats_tbl = crf_stats_tbl,
           train_featureset_tbl = train_featureset_tbl)
    plpy.execute(weights_sql)

    # Enforce ANALYZE to gather proper table statistics required to generate optimized query plans
    analyze_sql = """ ANALYZE {crf_weights_tbl} """.format(crf_weights_tbl = crf_weights_tbl)
    plpy.execute(analyze_sql)

    return "CRF Train successful. Results stored in the specified CRF stats and weights table"
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Assert that every column in cols exists in the given table

    @param cols List of required column names
    @param table_name Name of the table to inspect
    @param err_msg_tbl Label for the table that is inserted into the error
        message (e.g. "feature")
    """
    all_present = columns_exist_in_table(table_name, cols)
    required = ', '.join(cols)
    failure_msg = "CRF error: Missing required columns from %s table: %s" % (err_msg_tbl, required)
    _assert(all_present, failure_msg)
def _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                   crf_stats_tbl, crf_weights_tbl, max_iterations):
    """
    @brief Validate the arguments of lincrf_train and compute the tag size

    Checks, in this order: the three input tables exist, max_iterations is
    positive, the input tables carry the required columns, the label table
    is non-empty, both output table names are usable, and neither output
    table already exists.

    @param train_feature_tbl Name of the training feature table
    @param train_featureset_tbl Name of the feature set table
    @param label_tbl Name of the label table
    @param crf_stats_tbl Name of the stats table to be created
    @param crf_weights_tbl Name of the weights table to be created
    @param max_iterations Maximum number of iterations
    @return Number of rows in the label table, used as the tag size
    """
    # Input tables must exist.
    required_inputs = (
        (train_feature_tbl, "CRF error: Train feature table does not exist!"),
        (train_featureset_tbl, "CRF error: Train featureset table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for input_tbl, missing_msg in required_inputs:
        _assert(table_exists(input_tbl), missing_msg)

    _assert(max_iterations > 0,
            "CRF error: max iterations cannot be zero or negative")

    # Validate required columns
    required_columns = (
        (['doc_id', 'f_size', 'sparse_r', 'dense_m', 'sparse_m'],
         train_feature_tbl, "feature"),
        (['f_index', 'f_name', 'feature'], train_featureset_tbl, "featureset"),
        (['id', 'label'], label_tbl, "label"),
    )
    for col_names, owner_tbl, tbl_label in required_columns:
        _validate_columns(col_names, owner_tbl, tbl_label)

    # The number of labels is the tag size handed to the optimizer.
    count_rows = plpy.execute(""" SELECT count(*) FROM {label_tbl}
""".format(label_tbl = label_tbl))
    tag_size = count_rows[0]['count']
    _assert(tag_size > 0,
            "CRF error: Label table is empty")

    # Both output names are vetted before either existence check runs, so
    # the order in which errors surface stays stats-name, weights-name,
    # stats-exists, weights-exists.
    output_names = (
        (crf_stats_tbl, "CRF error: Invalid CRF stats table name"),
        (crf_weights_tbl, "CRF error: Invalid CRF weights table name"),
    )
    for out_tbl, invalid_msg in output_names:
        _assert(out_tbl is not None and
                out_tbl.lower().strip() not in ('null', ''),
                invalid_msg)

    output_clashes = (
        (crf_stats_tbl,
         "CRF error: CRF stats table already exist!"
         " Please provide a different table name."),
        (crf_weights_tbl,
         "CRF error: CRF weights table already exist!"
         " Please provide a different table name."),
    )
    for out_tbl, exists_msg in output_clashes:
        _assert(not table_exists(out_tbl, only_first_schema=True), exists_msg)

    return tag_size