| # coding=utf-8 |
| |
| """ |
| @file multilogistic.py_in |
| |
| @brief Multinomial Logistic Regression: Driver functions |
| |
| @namespace multilogistic |
| |
| Multinomial Logistic Regression: Driver functions |
| """ |
| |
| import plpy |
| |
def __runIterativeAlg(stateType, initialState, source, updateExpr,
        terminateExpr, maxNumIterations, cyclesPerIteration = 1):
    """
    Driver for an iterative algorithm

    A general driver function for most iterative algorithms: The state between
    iterations is kept in a variable of type <tt>stateType</tt>, which is
    initialized with <tt><em>initialState</em></tt>. During each iteration, the
    SQL statement <tt>updateSQL</tt> is executed in the database. Afterwards,
    the SQL query <tt>terminateSQL</tt> decides whether the algorithm
    terminates.

    All states are stored in the temporary table
    <tt>_madlib_iterative_alg(_madlib_iteration, _madlib_state)</tt>. Row 0
    holds the initial state, row <em>i</em> the state after the <em>i</em>-th
    update step. Any existing table of that name is dropped first. The table
    is \b not dropped on return, so the caller can read the final state from
    it.

    The loop stops after an update step as soon as one of the following holds
    (checked in this order):
    -# the newly computed state is NULL,
    -# <tt>cyclesPerIteration * maxNumIterations</tt> update steps have been
       executed,
    -# more than <tt>cyclesPerIteration</tt> update steps have been executed
       and <tt>terminateExpr</tt> evaluates to \c true.

    At least one update step is always executed.

    @param stateType SQL type of the state between iterations
    @param initialState The initial value of the SQL state variable (an SQL
        expression, inserted verbatim)
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        <tt>stateType</tt>. The expression may use the replacement fields
        <tt>"{state}"</tt>, <tt>"{iteration}"</tt>, and
        <tt>"{sourceAlias}"</tt>. Source alias is an alias for the source
        relation <tt><em>source</em></tt>.
    @param terminateExpr SQL expression that returns whether the algorithm should
        terminate. The expression may use the replacement fields
        <tt>"{oldState}"</tt>, <tt>"{newState}"</tt>, and
        <tt>"{iteration}"</tt>. It must return a BOOLEAN value.
    @param maxNumIterations Maximum number of iterations. Algorithm will then
        terminate even when <tt>terminateExpr</tt> does not evaluate to \c true
    @param cyclesPerIteration Number of aggregate function calls per iteration.

    @return Number of update steps that were executed, i.e., the value of
        <tt>_madlib_iteration</tt> of the last row written to
        <tt>_madlib_iterative_alg</tt>
    """

    # The templates are formatted in two stages. Here, only the caller's
    # expression is substituted; the doubled braces survive as single-brace
    # fields that are filled in anew for every iteration inside the loop.
    # Replacement fields contained in updateExpr/terminateExpr themselves
    # (e.g. "{state}") are resolved in that second stage, too.
    updateSQL = """
        INSERT INTO _madlib_iterative_alg
        SELECT
            {{iteration}},
            {updateExpr}
        FROM
            _madlib_iterative_alg AS st,
            {{source}} AS src
        WHERE
            st._madlib_iteration = {{iteration}} - 1
        """.format(updateExpr = updateExpr)
    terminateSQL = """
        SELECT
            {terminateExpr} AS should_terminate
        FROM
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
        ) AS older,
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}}
        ) AS newer
        """.format(terminateExpr = terminateExpr)
    checkForNullStateSQL = """
        SELECT _madlib_state IS NULL AS should_terminate
        FROM _madlib_iterative_alg
        WHERE _madlib_iteration = {iteration}
        """

    # Temporarily raise the message threshold to "error" while (re)creating
    # the state table, then restore the level that was in effect before.
    oldMsgLevel = plpy.execute(
        "SELECT setting FROM pg_settings WHERE name='client_min_messages'"
        )[0]['setting']
    plpy.execute("""
        SET client_min_messages = error;
        DROP TABLE IF EXISTS _madlib_iterative_alg;
        CREATE TEMPORARY TABLE _madlib_iterative_alg (
            _madlib_iteration INTEGER PRIMARY KEY,
            _madlib_state {stateType}
        );
        SET client_min_messages = {oldMsgLevel};
        """.format(stateType = stateType, oldMsgLevel = oldMsgLevel))

    # Total number of update steps after which we stop unconditionally
    maxNumUpdateSteps = cyclesPerIteration * maxNumIterations

    iteration = 0
    plpy.execute("""
        INSERT INTO _madlib_iterative_alg VALUES ({iteration}, {initialState})
        """.format(iteration = iteration, initialState = initialState))
    while True:
        iteration = iteration + 1
        plpy.execute(updateSQL.format(
            source = source,
            state = "(st._madlib_state)",
            iteration = iteration,
            sourceAlias = "src"))

        # A NULL state cannot be iterated on any further.
        if plpy.execute(checkForNullStateSQL.format(
                iteration = iteration))[0]['should_terminate']:
            break

        # The cap on the number of iterations applies unconditionally. It used
        # to be nested inside the "iteration > cyclesPerIteration" guard
        # below, so that maxNumIterations = 1 ran one update step too many.
        if iteration >= maxNumUpdateSteps:
            break

        # During the first cycle, the "older" state the new one would be
        # compared with is row 0, i.e., the caller-supplied initial state.
        # The convergence test is therefore only run from the second cycle on.
        if iteration > cyclesPerIteration and plpy.execute(
                terminateSQL.format(
                    iteration = iteration,
                    cyclesPerIteration = cyclesPerIteration,
                    oldState = "(older._madlib_state)",
                    newState = "(newer._madlib_state)")
                )[0]['should_terminate']:
            break

    # Note: We do not drop the temporary table
    return iteration
| |
| |
def compute_mlogregr(schema_madlib, source, depvar, numcategories, indepvar, optimizer,
        maxnumiterations, precision, ref_category, **kwargs):
    """
    Run the iterative fitting of a multinomial logistic regression model

    This method validates its arguments, builds the SQL expressions for one
    optimizer step and for the convergence test, and hands them to
    <tt>__runIterativeAlg</tt>. The only optimizer available here is
    iteratively reweighted least squares (<tt>'irls'</tt>);
    <tt>'newton'</tt> is accepted as another name for it.

    The state passed between iterations is of SQL type <tt>FLOAT8[]</tt> and
    starts out as <tt>NULL</tt>. It is computed by
    <tt>{schema_madlib}.__mlogregr_irls_step</tt> and is not returned by this
    function.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param depvar Name of dependent column in training data (of type INTEGER)
    @param numcategories Number of categories in the multilogistic regression
    @param indepvar Name of independent column in training data (of type
        DOUBLE PRECISION[]); it is cast to <tt>FLOAT8[]</tt> in the step call
    @param optimizer Name of the optimizer. 'newton' or 'irls': Iteratively
        reweighted least squares
    @param maxnumiterations Maximum number of iterations; must be at least 1
    @param precision Terminate if
        <tt>__internal_mlogregr_irls_step_distance(newState, oldState)</tt>
        is less than <tt>precision</tt>. Presumably, a negative value thus
        disables this convergence criterion (this assumes that the distance
        is never negative -- to be confirmed against the SQL function).
    @param ref_category The user-specified reference category
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.

    @return The value returned by <tt>__runIterativeAlg</tt>, i.e., the number
        of update steps that were executed
    """

    if maxnumiterations < 1:
        plpy.error("Number of iterations must be positive")

    # Map the alias onto the canonical name used in the SQL function names
    if optimizer == 'newton':
        optimizer = 'irls'
    elif optimizer != 'irls':
        plpy.error("Unknown optimizer requested. Must be 'newton' or 'irls'")

    # "{{state}}", "{{newState}}" and "{{oldState}}" are left behind as
    # single-brace replacement fields for __runIterativeAlg to fill in.
    stepExpr = """
        {schema_madlib}.__mlogregr_{optimizer}_step(
            ({depvar}),
            ({numcategories}),
            ({ref_category}),
            ({indepvar})::FLOAT8[],
            {{state}}
        )
        """.format(
            schema_madlib = schema_madlib,
            depvar = depvar,
            indepvar = indepvar,
            numcategories = numcategories,
            ref_category = ref_category,
            optimizer = optimizer)

    convergedExpr = """
        {schema_madlib}.__internal_mlogregr_{optimizer}_step_distance(
            {{newState}}, {{oldState}}
        ) < {precision}
        """.format(
            schema_madlib = schema_madlib,
            optimizer = optimizer,
            precision = precision)

    return __runIterativeAlg(
        stateType = "FLOAT8[]",
        initialState = "NULL",
        source = source,
        updateExpr = stepExpr,
        terminateExpr = convergedExpr,
        maxNumIterations = maxnumiterations)