| # coding=utf-8 |
| |
| """ |
| @file lmf_igd.py_in |
| |
| @brief Low-rank Matrix Factorization using IGD: Driver functions |
| |
| @namespace lmf_igd |
| |
| @brief Low-rank Matrix Factorization using IGD: Driver functions |
| """ |
| |
| from utilities.control import IterationController2S |
| from utilities.control import OptimizerControl |
| |
def compute_lmf_igd(schema_madlib, rel_args, rel_state, rel_source,
                    col_row, col_column, col_value, **kwargs):
    """
    Driver function for Low-rank Matrix Factorization using IGD

    Repeatedly runs the lmf_igd_step aggregate over all rows of rel_source,
    feeding in the state of the previous iteration, until either the
    iteration limit or the tolerance stored in rel_args is reached.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param rel_args Name of the (temporary) table containing all non-template
        arguments
    @param rel_state Name of the (temporary) table containing the
        inter-iteration states
    @param rel_source Name of the relation containing input points
    @param col_row Name of the row column
    @param col_column Name of the column (in the matrix sense) column
    @param col_value Name of the value column
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration number (i.e., the key) with which to look up the
        result in @c rel_state
    """

    # We disable ORCA since this function creates an edge case where
    # the performance is worse than the planner. (MADLIB-1170)
    with OptimizerControl(False):
        iterationCtrl = IterationController2S(
            rel_args=rel_args,
            rel_state=rel_state,
            stateType="DOUBLE PRECISION[]",
            truncAfterIteration=False,
            schema_madlib=schema_madlib,  # Identifiers start here
            rel_source=rel_source,
            col_row=col_row,
            col_column=col_column,
            col_value=col_value)
        with iterationCtrl as it:
            it.iteration = 0
            while True:
                # One pass over the source relation. The matrix entry is
                # cast to FLOAT8 (not integer) so that fractional values
                # are not rounded away before reaching the step function.
                it.update("""
                    SELECT
                        {schema_madlib}.lmf_igd_step(
                            (_src.{col_row})::integer,
                            (_src.{col_column})::integer,
                            (_src.{col_value})::FLOAT8,
                            (SELECT _state FROM {rel_state}
                                WHERE _iteration = {iteration}),
                            (_args.row_dim)::integer,
                            (_args.column_dim)::integer,
                            (_args.max_rank)::integer,
                            (_args.stepsize)::FLOAT8,
                            (_args.scale_factor)::FLOAT8)
                    FROM {rel_source} AS _src, {rel_args} AS _args
                    """)
                # Stop once the iteration budget is exceeded or the distance
                # between consecutive states falls below the tolerance.
                if it.test("""
                    {iteration} > _args.num_iterations OR
                    {schema_madlib}.internal_lmf_igd_distance(
                        _state_previous, _state_current) < _args.tolerance
                    """):
                    break
    return iterationCtrl.iteration