blob: 300b6d602806d20abd720fef20ff9f67bdb517ae [file] [log] [blame]
# coding=utf-8
"""
@file lmf_igd.py_in
@brief Low-rank Matrix Factorization using IGD: Driver functions
@namespace lmf_igd
@brief Low-rank Matrix Factorization using IGD: Driver functions
"""
from utilities.control import IterationController2S
from utilities.control import OptimizerControl
def compute_lmf_igd(schema_madlib, rel_args, rel_state, rel_source,
                    col_row, col_column, col_value, **kwargs):
    """
    Driver function for Low-rank Matrix Factorization using IGD

    Runs the lmf_igd_step aggregate over rel_source once per iteration,
    feeding it the state stored in rel_state for the current iteration.
    The loop ends as soon as the iteration counter exceeds
    _args.num_iterations or internal_lmf_igd_distance between the previous
    and the current state falls below _args.tolerance.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all
        non-template arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param col_row Name of the row column
    @param col_column Name of the column (in the matrix sense) column
    @param col_value Name of the value column
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in @c rel_state
    """
    # We disable ORCA since this function creates an edge case where
    # the performance is worse than the planner. (MADLIB-1170)
    with OptimizerControl(False):
        iterationCtrl = IterationController2S(
            rel_args=rel_args,
            rel_state=rel_state,
            stateType="DOUBLE PRECISION[]",
            truncAfterIteration=False,
            schema_madlib=schema_madlib,  # Identifiers start here
            rel_source=rel_source,
            col_row=col_row,
            col_column=col_column,
            col_value=col_value)
        with iterationCtrl as it:
            it.iteration = 0
            while True:
                # Row/column indices are integral, but the matrix entry is
                # real-valued: cast it to FLOAT8 so that fractional entries
                # are not rounded away before the gradient step.
                it.update("""
                    SELECT
                        {schema_madlib}.lmf_igd_step(
                            (_src.{col_row})::integer,
                            (_src.{col_column})::integer,
                            (_src.{col_value})::FLOAT8,
                            (SELECT _state FROM {rel_state}
                                WHERE _iteration = {iteration}),
                            (_args.row_dim)::integer,
                            (_args.column_dim)::integer,
                            (_args.max_rank)::integer,
                            (_args.stepsize)::FLOAT8,
                            (_args.scale_factor)::FLOAT8)
                    FROM {rel_source} AS _src, {rel_args} AS _args
                    """)
                # Stop on the iteration budget or on convergence of the state.
                if it.test("""
                    {iteration} > _args.num_iterations OR
                    {schema_madlib}.internal_lmf_igd_distance(
                        _state_previous, _state_current) < _args.tolerance
                    """):
                    break
        return iterationCtrl.iteration