# coding=utf-8

"""
@file glm.py_in

@brief Generalized Linear Models: Driver functions

@namespace glm

@brief Generalized Linear Models: Driver functions
"""
import plpy
from utilities.in_mem_group_control import GroupIterationController
from utilities.utilities import unique_string
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import _string_to_array
from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import extract_keyvalue_params
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.utilities import add_postfix
# ========================================================================


def __compute_glm(arg_dict):
    """
    Drive the iterative fit of the Generalized Linear Model

    A GroupIterationController is created from arg_dict and repeatedly asked
    to apply the family/link specific aggregate; after every update the
    stopping condition (iteration limit reached, or log-likelihood difference
    below the tolerance) is tested. When the test succeeds the controller is
    finalized and the loop ends.

    @param arg_dict Dictionary handed to GroupIterationController; the SQL
           fragments below reference its keys as placeholders
           (schema_madlib, family, link, col_dep_var, col_ind_var, rel_state,
           col_grp_state, max_iter, tolerance)

    @return The controller's iteration attribute after the loop has ended
    """
    controller = GroupIterationController(arg_dict)
    with controller as state:
        state.iteration = 0
        finished = False
        while not finished:
            # One update step: run the aggregate over the source data,
            # feeding it the state column of the state relation.
            state.update(
                """
                {schema_madlib}.__glm_{family}_{link}_agg(
                    ({col_dep_var})::double precision,
                    ({col_ind_var})::double precision[],
                    {rel_state}.{col_grp_state})
                """)
            # Stop on the iteration cap or once the change in
            # log-likelihood drops below the tolerance.
            finished = state.test(
                    """
                    {iteration} >= {max_iter}
                    OR {schema_madlib}.__glm_loglik_diff(
                        _state_previous, _state_current) < {tolerance}
                    """)
        state.final()

    return controller.iteration
# ========================================================================


def glm(schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, family_params=None, grouping_col=None,
        optim_params=None, verbose=False, **kwargs):
    """
    Train a Generalized Linear Model

    Validates the arguments, parses the family and optimizer parameter
    strings, and then delegates to __glm_compute.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source_table Name of relation containing the training data
    @param model_table Name of relation where model will be outputted
    @param dependent_varname Name of dependent column in training data
    @param independent_varname Name of independent column in training data (of type
                   DOUBLE PRECISION[])
    @param family_params Distribution of dependent variable
    @param grouping_col String of comma delimited group-by columns
    @param optim_params Parameters for optimizer
    @param verbose Forwarded to __glm_compute, which uses it to choose the
           client_min_messages level
    @param kwargs We allow the caller to specify additional arguments. They
           are forwarded unchanged to __glm_compute. The purpose of this is
           to allow the caller to unpack a dictionary whose element set is a
           superset of the required arguments by this function.

    @return The return value of __glm_compute, which is None
    """
    __glm_validate_args(schema_madlib, source_table, model_table,
                        dependent_varname, independent_varname, grouping_col)

    family_settings = __extract_family_params(schema_madlib, family_params)
    optimizer_settings = __extract_optim_params(schema_madlib, optim_params)

    compute_args = (schema_madlib, source_table, model_table,
                    dependent_varname, independent_varname, grouping_col,
                    family_settings, optimizer_settings)
    return __glm_compute(*compute_args, verbose=verbose, **kwargs)
# ========================================================================


def __glm_validate_args(schema_madlib, tbl_source, tbl_output, col_dep_var,
                        col_ind_var, grouping_col):
    """
    Validate the arguments

    Checks the source and output tables with the shared validators, rejects
    missing or blank dependent/independent column expressions, and, when
    grouping columns are given, verifies that they exist in the source table
    and do not collide with a column name used by the model output table.

    @param schema_madlib Name of the MADlib schema (not used by the checks)
    @param tbl_source Name of relation containing the training data
    @param tbl_output Name of relation where the model will be written
    @param col_dep_var Dependent variable expression
    @param col_ind_var Independent variable expression
    @param grouping_col String of comma delimited group-by columns, or a
           false value when no grouping is requested

    @return None; problems are reported through plpy.error
    """
    input_tbl_valid(tbl_source, 'GLM')

    output_tbl_valid(tbl_output, 'GLM')

    if col_dep_var is None or col_dep_var.strip() == '':
        plpy.error("GLM error: Invalid dependent column name!")

    if col_ind_var is None or col_ind_var.strip() == '':
        plpy.error("GLM error: Invalid independent column name!")

    if grouping_col:
        cols_in_tbl_valid(tbl_source, _string_to_array_with_quotes(grouping_col), 'GLM')
        # Names a grouping column must not take. The last three are columns
        # the output table always contains (num_rows_processed,
        # num_rows_skipped, num_iterations); a grouping column with one of
        # these names would otherwise pass validation and then produce a
        # duplicate column in the CREATE TABLE statement.
        reserved_names = frozenset((
            'coef', 'log_likelihood', 'std_err', 'z_stats',
            'p_values', 'odds_ratios', 'condition_no',
            'num_processed', 'num_missing_rows_skipped',
            'variance_covariance', 'dispersion', 't_stats',
            'num_rows_processed', 'num_rows_skipped', 'num_iterations'))
        intersect = frozenset(_string_to_array(grouping_col)).intersection(
            reserved_names)
        if intersect:
            # sorted() keeps the reported order stable between runs
            plpy.error("GLM error: Conflicting grouping column name.\n"
                       "Predefined name(s) {0} are not allow!".format(
                           ', '.join(sorted(intersect))))

    return None
# ========================================================================


def __extract_family_params(schema_madlib, family_params):
    """
    Parse and validate the family parameter string

    The string is parsed with extract_keyvalue_params for the keys 'family'
    and 'link'. The family is required and must be one of the supported
    families; the link, when given, must be one supported for that family,
    and otherwise defaults to the first link listed for the family.

    @param schema_madlib Name of the MADlib schema (not used here)
    @param family_params Family parameter string, e.g. 'family=...,link=...'

    @return The parsed dictionary, always containing 'family' and 'link'
    """
    parsed = extract_keyvalue_params(family_params,
                                     dict(family=str, link=str))
    # The first entry of each list is the default link for that family.
    links_by_family = {
        "poisson": ["log", "identity", "sqrt"],
        "gaussian": ["identity", "log", "inverse"],
        "gamma": ["inverse", "log", "identity"],
        "inverse_gaussian": ["sqr_inverse", "identity", "log", "inverse"],
        "binomial": ["logit", "probit"],
    }

    if "family" in parsed and parsed["family"] not in links_by_family:
        plpy.error("GLM error: {param_value} is not a valid "
                   "family!".format(param_value=parsed["family"]))

    if "family" not in parsed:
        plpy.error("GLM error: Required parameter family is missing!")

    if "link" not in parsed:
        # default link function
        parsed["link"] = links_by_family[parsed["family"]][0]
    elif parsed["link"] not in links_by_family[parsed["family"]]:
        plpy.error("GLM error: Invalid link function {link_func} for "
                   "family {family}!".format(link_func=parsed["link"],
                                             family=parsed["family"]))

    return parsed
# ========================================================================


def __extract_optim_params(schema_madlib, optim_params, module='GLM'):
    """
    Parse and validate the optimizer parameter string

    The string is parsed with extract_keyvalue_params using the defaults
    max_iter=100, optimizer='irls' and tolerance=1e-6, after which the
    values are range-checked.

    @param schema_madlib Name of the MADlib schema (not used here)
    @param optim_params Optimizer parameter string,
           e.g. 'max_iter=...,optimizer=...,tolerance=...'
    @param module Module name used as the prefix of error messages

    @return Dictionary with the keys 'max_iter', 'optimizer' and 'tolerance'
    """
    default_dict = dict(max_iter=100, optimizer='irls', tolerance=1e-6)
    optim_params_types = dict(max_iter=int, optimizer=str, tolerance=float)
    optim_params_dict = extract_keyvalue_params(optim_params,
                                                optim_params_types,
                                                default_dict)

    if optim_params_dict['max_iter'] <= 0:
        plpy.error("{0} error: max_iter must be positive!".format(module))
    # irls is the only optimizer accepted
    if optim_params_dict['optimizer'] != 'irls':
        plpy.error("{0} error: optimizer must be irls!".format(module))
    if optim_params_dict['tolerance'] <= 0:
        # The message names the parameter exactly as the user spells it.
        plpy.error("{0} error: tolerance must be positive!".format(module))

    return optim_params_dict
# ========================================================================


def __glm_compute(schema_madlib, tbl_source, tbl_output, col_dep_var, col_ind_var,
                  grouping_col, family_params, optim_params, verbose=False, **kwargs):
    """
    Fit the Generalized Linear Model and write the output and summary tables

    Runs the iterative fit through __compute_glm, then issues CREATE TABLE AS
    for tbl_output and for the companion summary table. No DROP is issued
    for either of those tables in this function; only the temporary state
    relation is dropped at the end.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param tbl_source Name of relation containing the training data
    @param tbl_output Name of relation the model is written to
    @param col_dep_var Dependent variable expression
    @param col_ind_var Independent variable expression
    @param grouping_col String of comma delimited group-by columns; any false
           value is treated as "no grouping"
    @param family_params Dictionary providing 'family' and 'link'
    @param optim_params Dictionary providing 'max_iter', 'optimizer' and
           'tolerance'
    @param verbose When true client_min_messages is set to info, otherwise
           to warning, for the duration of the call
    @param kwargs Accepted and not referenced in this function

    @return None
    """
    # Remember the session's message level so it can be restored at the end.
    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    if verbose:
        plpy.execute("SET client_min_messages TO info")
    else:
        plpy.execute("SET client_min_messages TO warning")

    # Placeholder values for the SQL templates below and for the iteration
    # controller. rel_state / col_grp_iteration / col_grp_state come from
    # unique_string() -- presumably generated names that will not clash with
    # user objects; confirm against the helper.
    args = {'schema_madlib': schema_madlib,
            'rel_source': tbl_source,
            'tbl_output': tbl_output,
            'col_dep_var': col_dep_var,
            'col_ind_var': col_ind_var,
            'rel_state': unique_string(),
            'col_grp_iteration': unique_string(),
            'col_grp_state': unique_string(),
            'state_type': schema_madlib + ".bytea8"
            }
    args.update(optim_params)
    args.update(family_params)

    # Build the comma separated grouping expression list; every entry gets a
    # "::text" cast appended. Without grouping, grouping_col is normalized to
    # None and the expression list is the literal string "NULL".
    if grouping_col:
        grouping_list = explicit_bool_to_text(
            tbl_source, _string_to_array_with_quotes(grouping_col), schema_madlib)
        for i in range(len(grouping_list)):
            grouping_list[i] += "::text"
        grouping_str = ','.join(grouping_list)
    else:
        grouping_col = None
        grouping_str = "NULL"
    args['grouping_col'] = grouping_col
    args['grouping_str'] = grouping_str

    # for binomial distribution, the dependent variable is of type boolean.
    # it's cast to integer here so that it can later be type cast to
    # double precision before computation begins.
    if family_params['family'] == 'binomial':
        args['col_dep_var'] = "(" + col_dep_var + ")::integer"

    # REAL COMPUTATION
    iteration_run = __compute_glm(args)

    # Reaching the iteration cap is reported as non-convergence.
    if iteration_run >= optim_params['max_iter']:
        plpy.warning("GLM warning: the computation did not converge in " +
                     str(optim_params['max_iter']) + " iterations!")

    # output table
    # Fragments that differ between the grouped and the ungrouped query:
    #   grouping_str1 -- select-list prefix ("<cols>," or empty)
    #   grouping_str2 -- GROUP BY expression ("1 = 1" when ungrouped)
    #   using_str / join_str -- join q1 with q2 on the grouping columns, or
    #                           fall back to a plain comma (cross) join
    grouping_str1 = "" if grouping_col is None else grouping_col + ","
    grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
    using_str = "" if grouping_str1 == "" else "using (" + grouping_col + ")"
    join_str = "," if grouping_str1 == "" else "join "

    # poisson and binomial expose the statistic column as z_stats; every
    # other family exposes it as t_stats. In both cases the value is read
    # from the z_stats field of the result composite.
    if family_params['family'] in ['poisson', 'binomial']:
        res_str = """
            (result).z_stats AS z_stats,
            (result).p_values AS p_values,
            (result).dispersion AS dispersion
            """
        glm_result = "__glm_result_z_stats"
    else:
        res_str = """
            (result).z_stats AS t_stats,
            (result).p_values AS p_values,
            (result).dispersion AS dispersion
            """
        glm_result = "__glm_result_t_stats"

    # q1: result of the state with the highest iteration number (per group);
    # q2: row count of the source table (per group). A NULL result yields 0
    # processed rows and counts all rows of the group as skipped.
    # NOTE: iteration_run is passed to format() but the template contains no
    # {iteration_run} placeholder.
    plpy.execute(
        """
        CREATE TABLE {tbl_output} AS
        SELECT
            {grouping_str1}
            (result).coef AS coef,
            (result).loglik AS log_likelihood,
            (result).std_err AS std_err,
            {res_str},
           (CASE WHEN result IS NULL THEN 0
             ELSE (result).num_rows_processed
             END)::bigint AS num_rows_processed,
            (CASE WHEN result IS NULL THEN num_rows
             ELSE num_rows - (result).num_rows_processed
             END)::bigint AS num_rows_skipped,
            {col_grp_iteration}::integer AS num_iterations
        FROM (
            SELECT
                {col_grp_iteration}, {grouping_str1} result, num_rows
            FROM (
                ( SELECT
                        {grouping_str1}
                        {schema_madlib}.{glm_result}({col_grp_state}) AS result,
                        {col_grp_iteration}
                  FROM
                        {rel_state}
                ) t
                JOIN
                ( SELECT
                        {grouping_str1}
                        max({col_grp_iteration}) AS {col_grp_iteration}
                  FROM {rel_state}
                  GROUP BY {grouping_str2}
                ) s
                USING ({grouping_str1} {col_grp_iteration})
            ) q1
            {join_str}
            ( SELECT
                    {grouping_str1}
                    count(*) AS num_rows
              FROM {rel_source}
              GROUP BY {grouping_str2}
            ) q2
            {using_str}
        ) q3
        """.format(grouping_str1=grouping_str1,
                   grouping_str2=grouping_str2,
                   iteration_run=iteration_run,
                   using_str=using_str,
                   join_str=join_str,
                   res_str=res_str,
                   glm_result=glm_result,
                   **args))

    # summary table
    # A group counts as failed when its coef column is NULL.
    failed_groups = plpy.execute("""
        SELECT count(*) AS num_failed_groups
        FROM {tbl_output}
        WHERE coef IS NULL
        """.format(**args))[0]
    all_groups = plpy.execute("""
        SELECT count(*) AS num_all_groups
        FROM {tbl_output}
        """.format(**args))[0]
    total_rows = plpy.execute("""
        SELECT
            sum(num_rows_processed) AS total_rows_processed,
            sum(num_rows_skipped) AS total_rows_skipped
        FROM {tbl_output}
        """.format(tbl_output=tbl_output))[0]

    # Merge the aggregate rows into args so the summary template can use
    # their column names as placeholders.
    args.update(failed_groups)
    args.update(all_groups)
    args.update(total_rows)

    # Summary table name is derived from the output table name via
    # add_postfix (presumably appends "_summary" -- confirm in the helper).
    # dcol is the caller's original dependent variable expression, i.e.
    # without the ::integer cast added above for the binomial family.
    tbl_output_summary = add_postfix(tbl_output, "_summary")
    plpy.execute("""
        CREATE TABLE {tbl_output_summary} AS
        SELECT
            'glm'::varchar                      AS method,
            '{rel_source}'::varchar             AS source_table,
            '{tbl_output}'::varchar             AS out_table,
            $madlib_super_quote${dcol}$madlib_super_quote$::varchar
                                                AS dependent_varname,
            $madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
                                                AS independent_varname,
            'family={family}, ' ||
            'link={link}'::varchar              AS family_params,
            {g_str}::text                       AS grouping_col,
            'optimizer={optimizer}, ' ||
            'max_iter={max_iter}, '   ||
            'tolerance={tolerance}'::varchar    AS optimizer_params,
            {num_all_groups}::integer           AS num_all_groups,
            {num_failed_groups}::integer        AS num_failed_groups,
            {total_rows_processed}::bigint      AS total_rows_processed,
            {total_rows_skipped}::bigint        AS total_rows_skipped
        """.format(g_str="'" + grouping_col + "'" if grouping_col else "NULL",
                   tbl_output_summary=tbl_output_summary,
                   dcol=col_dep_var,
                   **args))

    # clean up
    # Drop the temporary state relation and restore the message level that
    # was saved at the top of the function.
    plpy.execute("""DROP TABLE IF EXISTS pg_temp.{rel_state} """.format(**args))
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None


# ========================================================================

def glm_help_msg(schema_madlib, message, **kwargs):
    """ Help message for generalized linear regression model

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. A false value
           (None or '') selects the summary; 'usage', 'help' and '?' select
           the usage text; anything else selects the "no such option" text.
    @param kwargs Accepted and ignored

    Returns:
      A string, contains the help message
    """
    if not message:

        # The summary points the reader at the usage text, so it has to
        # spell out the command that prints it.
        help_string = """
------------------------------------------------------------------
                        SUMMARY
------------------------------------------------------------------
Generalized Linear Model:

Function to fit a generalized linear model, relating responses to linear combinations
of predictor variables.

For details on function usage:
    SELECT {schema_madlib}.glm('usage')
        """
    elif message in ['usage', 'help', '?']:

        # Table and column names below follow the glm() argument list
        # (model_table) and the columns the summary table is created with
        # (out_table, grouping_col as text).
        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm(
    source_table,        -- name of input table
    model_table,         -- name of model table
    dependent_varname,   -- name of dependent variable
    independent_varname, -- names of independent variables
    family_params,       -- parameters for family distribution and link function

                            usage:
                            'family=<family_name>,link=<link_function_name>'

                            supported values include:
                            'family=poisson,link=identity|log|sqrt' (default link: log)
                            'family=binomial,link=logit|probit' (default link: logit)
                            'family=gaussian,link=identity|log|inverse' (default link: identity)
                            'family=inverse_gaussian,link=identity|log|inverse|sqr_inverse' (default link: sqr_inverse i.e. 1/mu^2)
                            'family=gamma,link=identity|log|inverse' (default link: inverse)

    grouping_col,        -- optional, default NULL, names of columns to group-by
    optimizer_params,    -- optional, parameters for optimizer

                            usage:
                            'max_iter=<max_num_iterations>,optimizer=<optimizer_name>,tolerance=<tolerance_value>'

                            default values include:
                            max_iter=100
                            optimizer='irls'
                            tolerance=1e-6

    verbose              -- optional, default FALSE, whether to print debug info
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output table ('model_table' above) has the following columns:
    <...>                                    -- grouping columns
    'coef'               double precision[], -- vector of coefficients
    'log_likelihood'     double precision,   -- log likelihood
    'std_err'            double precision[], -- vector of standard errors
    'z_stats'/'t_stats'  double precision[], -- vector of z-statistics if family=Poisson or Binomial; vector of t-statistics otherwise
    'p_values'           double precision[], -- vector of p-values
    'dispersion'         double precision[], -- dispersion parameter (if z-stats is used, dispersion is set to be constant 1)
    'num_rows_processed' bigint,             -- numbers of rows processed
    'num_rows_skipped'   bigint,             -- numbers of rows skipped
    'num_iterations'     integer             -- number of iterations run

A summary table named <model_table>_summary is also created at the same time, which has:
    method               varchar, -- modeling method name: 'glm'
    source_table         varchar, -- the data source table name
    out_table            varchar, -- the output table name
    dependent_varname    varchar, -- the dependent variable
    independent_varname  varchar, -- the independent variable
    family_params        varchar, -- family distribution and link function
    grouping_col         text,    -- grouping columns used in the regression
    optimizer_params     varchar, -- 'optimizer=...,max_iter=...,tolerance=...'
    num_all_groups       integer, -- how many groups
    num_failed_groups    integer, -- how many groups' fitting processes failed
    total_rows_processed bigint,  -- total numbers of rows processed
    total_rows_skipped   bigint,  -- total numbers of rows skipped
        """
    else:
        help_string = "No such option. Use {schema_madlib}.glm('help')"

    return help_string.format(schema_madlib=schema_madlib)


# ========================================================================

def glm_predict_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the glm_predict function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. A false value
           selects the summary; 'usage', 'help' and '?' select the usage
           text; anything else selects the "no such option" text.
    @param kwargs Accepted and ignored

    Returns:
      A string, contains the help message
    """
    summary_text = """
----------------------------------------------------------------
                        SUMMARY
----------------------------------------------------------------
Prediction function for generalized linear regression:

Estimate the conditional mean for the new predictors. The length of input
coefficients should match the number of variables in the new predictors.

For details on function usage:
    SELECT {schema_madlib}.glm_predict('usage')

For prediction functions related to specific distributions:
    SELECT {schema_madlib}.glm_predict_poisson('help')
    SELECT {schema_madlib}.glm_predict_binomial('help')
        """
    usage_text = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict(
    coef,            -- array of coefficients derived from glm() function
    col_ind_var,     -- array of independent variables for new predictors
    link             -- string indicating the link function specifid in glm()
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated conditional
means for the new predictors.
        """
    unknown_text = "No such option. Use {schema_madlib}.glm_predict('help')"

    # Pick the template first, substitute the schema name once at the end.
    if not message:
        chosen = summary_text
    elif message in ('usage', 'help', '?'):
        chosen = usage_text
    else:
        chosen = unknown_text

    return chosen.format(schema_madlib=schema_madlib)


# ============================================================================
# Help messages for specialized prediction functions
# ============================================================================
def glm_predict_poisson_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the glm_predict_poisson function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. A false value
           (None or ''), 'usage', 'help' and '?' all select the usage text;
           anything else selects the "no such option" text.
    @param kwargs Accepted and ignored

    Returns:
      A string, contains the help message
    """
    # A missing (None) indicator is handled like the empty string, in line
    # with the other help functions of this module, instead of being
    # reported as an unknown option.
    if not message or message in ['usage', 'help', '?']:

        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict_poisson(
    coef,            -- array of coefficients derived from glm() function
    col_ind_var,     -- array of independent variables for new predictors
    link             -- string indicating the link function specifid in glm()
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated conditional
mean for the new predictors, rounded to the nearest integral value.

For more details on glm predict functions:
    SELECT {schema_madlib}.glm_predict('usage')
    """
    else:
        help_string = "No such option. Use {schema_madlib}.glm_predict_poisson('help')"

    return help_string.format(schema_madlib=schema_madlib)


def glm_predict_binomial_help_msg(schema_madlib, message, **kwargs):
    """ Help message for the glm_predict_binomial function

    @param schema_madlib Name of the MADlib schema, substituted into the text
    @param message A string, the help message indicator. A false value
           (None or ''), 'usage', 'help' and '?' all select the usage text;
           anything else selects the "no such option" text.
    @param kwargs Accepted and ignored

    Returns:
      A string, contains the help message
    """
    # A missing (None) indicator is handled like the empty string, in line
    # with the other help functions of this module, instead of being
    # reported as an unknown option.
    if not message or message in ['usage', 'help', '?']:

        help_string = """
------------------------------------------------------------------
                        USAGE
------------------------------------------------------------------
SELECT {schema_madlib}.glm_predict_binomial(
    coef,            -- array of coefficients derived from glm() function
    col_ind_var,     -- array of independent variables for new predictors
    link             -- string indicating the link function specifid in glm()
);

------------------------------------------------------------------
                        OUTPUT
------------------------------------------------------------------
The output is a table with one column which gives the estimated output category
of the dependent variable as a boolean value.

For more details on glm predict functions:
    SELECT {schema_madlib}.glm_predict('usage')
    """
    else:
        help_string = "No such option. Use {schema_madlib}.glm_predict_binomial('help')"

    return help_string.format(schema_madlib=schema_madlib)
