
import plpy
from elastic_net_models import _elastic_net_gaussian_igd_train
from elastic_net_models import _elastic_net_gaussian_fista_train
from elastic_net_models import _elastic_net_binomial_fista_train
from elastic_net_models import _elastic_net_binomial_igd_train
from elastic_net_utils import _generate_warmup_lambda_sequence

from elastic_net_utils import BINOMIAL_FAMILIES, GAUSSIAN_FAMILIES, OPTIMIZERS

from utilities.validate_args import is_col_array
from utilities.utilities import is_string_formatted_as_array_expression
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import get_cols_and_types
from utilities.utilities import get_grouping_col_str
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import explicit_bool_to_text
from utilities.utilities import extract_keyvalue_params

from utilities.control import MinWarning

from utilities.utilities import _string_to_array_with_quotes
from utilities.utilities import _string_to_array
from utilities.utilities import is_psql_numeric_type
from utilities.utilities import _assert
from utilities.utilities import add_postfix
import re

from validation.internal.cross_validation import CrossValidator
# ------------------------------------------------------------------------


def elastic_net_help(schema_madlib, family_or_optimizer=None, **kwargs):
    """
    Return the help text for elastic net regularization.

    @param schema_madlib        Name of the schema where MADlib is installed;
                                it is substituted into the example queries
    @param family_or_optimizer  Help topic, matched case-insensitively.
                                None, 'help' or '?' select the summary.
                                Other topics are 'usage', 'predict', a
                                response family ('gaussian'/'linear',
                                'binomial'/'logistic') or an optimizer
                                ('igd', 'fista')

    @return The help message as a string. An unrecognized topic yields a
            short error message that points back to the summary.
    """
    if family_or_optimizer is None:
        topic = "help"
    else:
        # Normalize once instead of lower-casing in every comparison below
        topic = family_or_optimizer.lower()

    if topic in ("help", "?"):
        return """
    ----------------------------------------------------------------
                            Summary
    ----------------------------------------------------------------
    Right now, gaussian (linear) and binomial (logistic) families
    are supported!
    --
    Run:
    SELECT {schema_madlib}.elastic_net_train('gaussian');
    or
    SELECT {schema_madlib}.elastic_net_train('binomial');
    to see more help.
    --
    Run:  SELECT {schema_madlib}.elastic_net_train('usage');
    to see how to use.
    --
    Run:  SELECT {schema_madlib}.elastic_net_train('predict');
    to see how to predict.
        """.format(schema_madlib=schema_madlib)

    # 'help' and '?' were already answered by the summary above
    if topic == "usage":
        return """
    ----------------------------------------------------------------
                            USAGE
    ----------------------------------------------------------------
    SELECT {schema_madlib}.elastic_net_train (
        'tbl_source',      -- Data table
        'tbl_result',      -- Result table
        'col_dep_var',     -- Dependent variable, can be an expression
        'col_ind_var',     -- Independent variable, can be an expression
                                or '*'
        'regress_family',  -- 'gaussian' (or 'linear'). 'binomial'
                                (or 'logistic')
        alpha,             -- Elastic net control parameter, value in [0, 1]
        lambda_value,      -- Regularization parameter, positive
        standardize,       -- Whether to normalize the data
        'grouping_col',    -- Group by which columns. (DEFAULT: NULL)
        'optimizer',       -- Name of optimizer. (DEFAULT: 'fista')
        'optimizer_params',-- Comma-separated string of optimizer parameters
        'excluded',        -- Column names excluded from '*' (DEFAULT = NULL)
        max_iter,          -- Maximum iteration number (DEFAULT = 1000)
        tolerance          -- Stopping criteria (DEFAULT = 1e-4)
    );
    ----------------------------------------------------------------
                            OUTPUT
    ----------------------------------------------------------------
    The output table ('tbl_result' above) has the following columns:
    grouping_col      TEXT,               --'Distinct values of grouping_col'
    family            TEXT,               --'gaussian' or 'binomial'
    features          TEXT[],             -- All feature column names
    features_selected TEXT[],             -- Features with non-zero coefficients
    coef_nonzero      DOUBLE PRECISION[], -- Non-zero coefficients
    coef_all          DOUBLE PRECISION[], -- All coefficients
    intercept         DOUBLE PRECISION,   -- Intercept of the linear fit
    log_likelihood    DOUBLE PRECISION,   -- log-likelihood of the fit
    standardize       BOOLEAN,            -- Whether the data was standardized
                                          -- before fitting
    iteration_run     INTEGER             -- How many iterations were actually run

    If the independent variable is a column with type of array, features
    and features_selected will output indices of the array.
        """.format(schema_madlib=schema_madlib)

    if topic == "predict":
        return """
    ----------------------------------------------------------------
                            Prediction
    ----------------------------------------------------------------
    SELECT {schema_madlib}.elastic_net_predict(
        'regress_family', -- 'gaussian' (or 'linear'). 'binomial'
                          --  (or 'logistic') will be supported
        coefficients,     -- Fitting coefficients as a double array
        intercept,
        ind_var           -- independent variables
    ) FROM tbl_result, tbl_new_source;

    When predicting with binomial models, the return value is 1
    if the predicted result is True, and 0 if the prediction is
    False.

    OR -------------------------------------------------------------

    (1) SELECT {schema_madlib}.elastic_net_gaussian_predict (
            coefficients, intercept, ind_var
        ) FROM tbl_result, tbl_new_source;

    (2) SELECT {schema_madlib}.elastic_net_binomial_predict (
            coefficients, intercept, ind_var
        ) FROM tbl_result, tbl_new_source;


    (3) SELECT {schema_madlib}.elastic_net_binomial_prob (
            coefficients, intercept, ind_var
        ) FROM tbl_result, tbl_new_source;

        This returns probability values for the class being 'True'.

    OR -------------------------------------------------------------

    SELECT {schema_madlib}.elastic_net_predict(
        'tbl_model',      -- Result table of elastic_net_train
        'tbl_new_source', -- New data source
        'col_id',         -- Unique ID column
        'tbl_predict'     -- Prediction result
    );
    will put all prediction results into a table. This can be
    used together with cross_validation_general() function.

    When predicting with binomial models, the predicted values
    are BOOLEAN.
        """.format(schema_madlib=schema_madlib)

    if topic in ("gaussian", "linear"):
        return """
    ----------------------------------------------------------------
    Fitting linear models
    ----------------------------------------------------------------
    Supported optimizer:
    (1) Incremental gradient descent method ('igd')
    (2) Fast iterative shrinkage thresholding algorithm ('fista')

    Default is 'fista'
    --
    Run:
    SELECT {schema_madlib}.elastic_net_train('optimizer');
    to see more help on each optimizer.
        """.format(schema_madlib=schema_madlib)

    if topic in ("binomial", "logistic"):
        return """
    ----------------------------------------------------------------
    Fitting logistic models
    ----------------------------------------------------------------
    The dependent variable must be a BOOLEAN.

    Supported optimizer:
    (1) Incremental gradient descent method ('igd')
    (2) Fast iterative shrinkage thresholding algorithm ('fista')

    Default is 'fista'
    --
    Run:
    SELECT {schema_madlib}.elastic_net_train('optimizer');
    to see more help on each optimizer.
        """.format(schema_madlib=schema_madlib)

    if topic == "igd":
        return """
    ----------------------------------------------------------------
    Incremental gradient descent (IGD) method
    ----------------------------------------------------------------
    Right now, it supports fitting both linear and logistic models.

    In order to obtain sparse coefficients, a
    modified version of IGD is actually used.

    Parameters --------------------------------
    stepsize          - default is 0.01
    threshold         - default is 1e-10. When a coefficient is really
                        small, set it to be 0
    warmup            - default is False
    warmup_lambdas    - default is Null
    warmup_lambda_no  - default is 15. How many lambda's are used in
                        warm-up, will be overridden if warmup_lambdas
                        is not NULL
    warmup_tolerance  - default is the same as tolerance. The value
                        of tolerance used during warmup.
    n_folds           - default is 1, Number of cross validation folds.
                        Set this to greater than 1 if CV over lambda is required.
    validation_result - Name of the table to store the cross validation results.
    parallel          - default is True. Run the computation on
                        multiple segments or not.

    When warmup is True or if n_folds > 1, and warmup_lambdas is NULL, a series
    of lambda values will be automatically generated and used.

    Reference --------------------------------
    [1] Shai Shalev-Shwartz and Ambuj Tewari, Stochastic Methods for l1
        Regularized Loss Minimization. Proceedings of the 26th Interna-
        tional Conference on Machine Learning, Montreal, Canada, 2009.
        """

    if topic == "fista":
        return """
    ----------------------------------------------------------------
    Fast iterative shrinkage thresholding algorithm
    with backtracking for stepsizes
    ----------------------------------------------------------------
    Right now, it supports fitting both linear and logistic models.

    Parameters --------------------------------
    max_stepsize        - default is 4.0
    eta                 - default is 1.2, if stepsize does not work
                          stepsize/eta will be tried
    warmup              - default is False
    warmup_lambdas      - default is NULL, which means that lambda
                          values will be automatically generated
    warmup_lambda_no    - default is 15. How many lambda's are used in
                          warm-up, will be overridden if warmup_lambdas
                          is not NULL
    warmup_tolerance    - default is the same as tolerance. The value
                          of tolerance used during warmup.
    use_active_set      - default is False. Sometimes active-set method
                          can speed up the calculation.
    activeset_tolerance - default is the same as tolerance. The
                          value of tolerance used during active set
                          calculation
    random_stepsize     - default is False. Whether add some randomness
                          to the step size. Sometimes, this can speed
                          up the calculation.
    n_folds             - default is 1. Number of cross validation folds.
                          Set this to greater than 1 if CV over lambda is required.
    validation_result   - Name of the table to store the cross validation results.

    When warmup is True and warmup_lambdas is NULL, warmup_lambda_no
    of lambda values will be automatically generated and used.

    Reference --------------------------------
    [1] Beck, A. and M. Teboulle (2009), A fast iterative
        shrinkage-thresholding algorithm for linear inverse
        problems. SIAM J. on Imaging Sciences 2(1), 183-202.
        """

    # Anything else is neither a known response family nor a known optimizer
    return """
    Elastic Net error: Not a supported response family or optimizer
    Run:
    SELECT {schema_madlib}.elastic_net_train();
    for help
    """.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------


def elastic_net_train(schema_madlib, source_table, model_table, dependent_varname,
                      independent_varname, regress_family, alpha, lambda_value,
                      standardize, grouping_col, optimizer,
                      optimizer_params, excluded, max_iter, tolerance,
                      **kwargs):
    """
    A wrapper for all variants of elastic net regularization.

    Validates the response family and optimizer names, splits the cross
    validation settings out of optimizer_params, optionally runs cross
    validation, and finally trains the model.

    @param schema_madlib        Name of the schema where MADlib is installed
    @param source_table         Name of data source table
    @param model_table          Name of the table to store the results,
                                will return fitting coefficients and
                                likelihood
    @param dependent_varname    Name of dependent variable column
    @param independent_varname  Name of independent variable column,
                                independent variable is an array
    @param regress_family       Response type, 'gaussian' or 'binomial'
    @param alpha                The elastic net parameter, [0, 1]
    @param lambda_value         The regularization parameter
    @param standardize          Whether to normalize the variables
    @param grouping_col         Columns to group by, or None
    @param optimizer            The optimization algorithm, for example 'igd'
    @param optimizer_params     Parameters of the above optimizer as a
                                comma-separated 'name = value' string
    @param excluded             Which variables are excluded when
                                independent_varname == "*"
    @param max_iter             Maximum number of iterations
    @param tolerance            Stopping criterion

    @return None. Errors are reported through plpy.error.
    """
    with MinWarning("warning"):
        if regress_family is None:
            plpy.error("""
                       Elastic Net error: Please enter a valid response family name!
                       Run:
                       SELECT {schema_madlib}.elastic_net_train();
                       for supported response family.
                       """.format(schema_madlib=schema_madlib))

        if optimizer is None:
            plpy.error("""
                       Elastic Net error: Please enter a valid optimizer name!
                       Run:
                       SELECT {schema_madlib}.elastic_net_train('gaussian');
                       for supported optimizers.
                       """.format(schema_madlib=schema_madlib))

        regress_family = regress_family.lower()
        optimizer = optimizer.lower()

        if (regress_family not in (BINOMIAL_FAMILIES + GAUSSIAN_FAMILIES) or
                optimizer not in OPTIMIZERS):
            plpy.error("""
                Elastic Net error: Not a supported response family or supported
                optimizer of the given response family!

                Run:
                    SELECT {schema_madlib}.elastic_net_train();
                for help.
                """.format(schema_madlib=schema_madlib))

        cv_param, optimizer_params = _get_cv_optimizer_params(
            optimizer_params, alpha, lambda_value)

        # Take a private snapshot of the current local variables. The mapping
        # returned by locals() must not be modified: the interpreter may
        # re-synchronize it with the real locals (for example under a
        # debugger or coverage tracer), which would silently discard the
        # values chosen by cross validation below.
        args = dict(locals())
        if cv_param['n_folds'] > 1:
            # Replace scalar alpha/lambda_value by their candidate lists and
            # let cross validation write the selected values back into args
            args.update(cv_param)
            _cross_validate_en(args)
        _internal_elastic_net_train(**args)
# ------------------------------------------------------------------------


def _internal_elastic_net_train(
        schema_madlib, source_table, model_table, dependent_varname,
        independent_varname, grouping_col,
        regress_family, alpha, lambda_value,
        standardize, optimizer, optimizer_params, excluded,
        max_iter, tolerance, **kwargs):
    """
    Resolve the inputs and dispatch to the trainer that matches the
    (regress_family, optimizer) pair.

    The independent variable is normalized through analyze_input_str and the
    grouping columns through get_grouping_col_str before any trainer runs.
    For the binomial family the dependent variable expression is wrapped in
    a cast to boolean. If no trainer matches, nothing is trained.

    @return None
    """
    summary_tbl = add_postfix(model_table, "_summary")

    # handle all special cases of independent_varname
    independent_varname, feature_names = analyze_input_str(
        schema_madlib, source_table,
        independent_varname, dependent_varname, excluded)

    # get the grouping info; these names are passed to get_grouping_col_str
    # as reserved column names
    grouping_str, grouping_col = get_grouping_col_str(
        schema_madlib, "Elastic Net",
        ['regress_family', 'coef_all',
         'features_selected',
         'coef_nonzero', 'intercept',
         'log_likelihood', 'standardize',
         'iteration_run'],
        source_table, grouping_col)

    # Pick the training routine first, then call it once
    trainer = None
    if regress_family in GAUSSIAN_FAMILIES:
        if optimizer == OPTIMIZERS.igd:
            trainer = _elastic_net_gaussian_igd_train
        elif optimizer == OPTIMIZERS.fista:
            trainer = _elastic_net_gaussian_fista_train
    elif regress_family in BINOMIAL_FAMILIES:
        if optimizer == OPTIMIZERS.igd:
            trainer = _elastic_net_binomial_igd_train
        elif optimizer == OPTIMIZERS.fista:
            trainer = _elastic_net_binomial_fista_train
        if trainer is not None:
            dependent_varname = "(" + dependent_varname + ")::boolean"

    if trainer is not None:
        trainer(
            schema_madlib, source_table, independent_varname, dependent_varname,
            model_table, summary_tbl, lambda_value, alpha, standardize,
            optimizer_params, max_iter, tolerance, feature_names,
            grouping_str, grouping_col, **kwargs)
    return None
# ----------------------------------------------------------------------


def _get_cv_optimizer_params(param_str, alpha, smallest_lambda):
    """
    Split the cross validation settings out of the optimizer parameters.

    @param param_str        Optimizer parameter string given by the user,
                            may be None or empty
    @param alpha            Elastic net parameter given by the user; it is
                            the default (single) alpha candidate
    @param smallest_lambda  Regularization parameter given by the user; used
                            to generate lambda candidates when none are
                            listed explicitly

    @return Tuple (cv_params, param_str). cv_params is a dict with the keys
            n_folds, lambda_value, alpha, n_lambdas and validation_result.
            When n_folds > 1, lambda_value and alpha are lists of candidate
            values. param_str is returned unchanged unless lambda candidates
            were listed or generated, in which case ', warmup=False' is
            appended.
    """
    # name -> (default value, expected type)
    cv_params_defaults = {
        "n_folds": (1, int),
        "lambda_value": (None, list),
        "alpha": ([alpha], list),
        "n_lambdas": (15, int),
        "validation_result": (None, str)
    }
    param_defaults = dict((k, v[0]) for k, v in cv_params_defaults.items())
    param_types = dict((k, v[1]) for k, v in cv_params_defaults.items())

    if not param_str:
        return param_defaults, param_str

    name_value = extract_keyvalue_params(param_str, param_types, param_defaults,
                                         ignore_invalid=True)
    if name_value['n_folds'] > 1:
        # Build real lists (not map objects): _cross_validate_en recognizes
        # candidate sets with isinstance(..., list), and map() only returns
        # a list on Python 2.
        if name_value['lambda_value']:
            name_value['lambda_value'] = [float(value) for value
                                          in name_value['lambda_value']]
            # no warmup when cross validating on lambda
            param_str += ', warmup=False'
        elif name_value['n_lambdas']:
            name_value['lambda_value'] = _generate_warmup_lambda_sequence(
                smallest_lambda, name_value['n_lambdas'])
            # no warmup when cross validating on lambda
            param_str += ', warmup=False'
        else:
            name_value['lambda_value'] = [float(smallest_lambda)]
        name_value['alpha'] = [float(value) for value in name_value['alpha']]
    return name_value, param_str
# ------------------------------------------------------------------------


def _cross_validate_en(args):
    """
    Run cross validation over 'lambda_value' and/or 'alpha' and store the
    selected values back into args.

    args is modified in place: a single-element candidate list is replaced
    by its only element, and after validation args is updated with
    val_res.top('sub_args').

    @param args Dict of training arguments; must contain 'n_folds',
                'lambda_value' and 'alpha', and also 'grouping_col' when
                n_folds > 1. 'regress_family' is read only when validation
                actually runs.
    """
    n_folds = args['n_folds']
    if n_folds > 1 and args['grouping_col']:
        plpy.error('Elastic Net Error: cross validation with grouping is not supported!')

    # Only list-valued parameters with more than one candidate are searched
    search_grid = {}
    for name in ('lambda_value', 'alpha'):
        candidates = args[name]
        if not isinstance(candidates, list):
            continue
        if len(candidates) > 1:
            search_grid[name] = candidates
        else:
            # a single candidate is just a scalar setting
            args[name] = candidates[0]

    if not search_grid:
        if n_folds <= 1:
            # no cross validation
            return
        if n_folds > 1:
            plpy.warning('Elastic Net Warning: n_folds > 1 but no '
                         'cross validation parameter provided')
            return
    elif n_folds <= 1:
        plpy.error('Elastic Net Error: All parameters must be scalar '
                   'when n_folds is 0 or 1')

    if args['regress_family'] in BINOMIAL_FAMILIES:
        scorer = 'classification'
    else:
        scorer = 'regression'
    validator = CrossValidator(_internal_elastic_net_train,
                               elastic_net_predict_all, scorer, args)
    outcome = validator.validate(search_grid, args['n_folds'])
    if 'validation_result' in args:
        outcome.output_tbl(args['validation_result'])
    args.update(outcome.top('sub_args'))
# ------------------------------------------------------------------------------


def _check_args(tbl_source, col_ind_var, col_dep_var):
    """
    Validate the raw inputs before analyze_input_str interprets them

    Reports an error through plpy.error when any argument is None or blank,
    or when the data table does not exist or is empty.

    @param tbl_source Data table
    @param col_ind_var Independent variables
    @param col_dep_var Dependent variables
    """
    required = (tbl_source, col_ind_var, col_dep_var)

    if not all(value is not None for value in required):
        plpy.error("Elastic Net error: You have unsupported NULL value(s) in the arguments!")
    if not all(value.strip() != '' for value in required):
        plpy.error("Elastic Net error: You have unsupported EMPTY value(s) in the arguments!")

    table_label = "Elastic Net error: Data table " + tbl_source
    if not table_exists(tbl_source):
        plpy.error(table_label + " does not exist!")

    if table_is_empty(tbl_source):
        plpy.error(table_label + " is empty!")
# ------------------------------------------------------------------------


def analyze_input_str(schema_madlib, tbl_source,
                      col_ind_var, col_dep_var, excluded):
    """
    Make input strings and output strings compatible with functions

    @param schema_madlib Name of the schema where MADlib is installed
    @param tbl_source Data table
    @param col_ind_var Independent variables: "*", the name of a single
                       array column, or an expression producing an array
    @param col_dep_var Dependent variables
    @param excluded Which variables are excluded when col_ind_var == "*".
                    When col_ind_var is a single array column, this is
                    forwarded to analyze_single_input_str and interpreted
                    there as array indices.

    @return Tuple (independent variable expression, list of feature names)
    """
    _check_args(tbl_source, col_ind_var, col_dep_var)

    outstr_array = []
    if col_ind_var == "*":
        col_types_dict = dict(get_cols_and_types(tbl_source))
        cols = col_types_dict.keys()

        s = _string_to_array(excluded) if excluded is not None else []

        # every column except the excluded ones and the dependent variable
        for each_col in cols:
            if each_col not in s and each_col != col_dep_var:
                outstr_array.append(each_col)

        if not outstr_array:
            plpy.error("Elastic Net error: All columns from independent variables "
                       "have been excluded")
        elif (len(outstr_array) == 1 and
                col_types_dict[outstr_array[0]].lower() == 'array'):
            # the only remaining column is itself an array of features
            col_ind_var = outstr_array[0]
            return analyze_single_input_str(schema_madlib, tbl_source,
                                            col_ind_var)
        else:
            included_col_types = [col_types_dict[i] for i in outstr_array]
            if not all(is_psql_numeric_type(i)
                       for i in included_col_types):
                plpy.error("Elastic Net error: All columns to be included in the "
                           "independent variables should be of the numeric type.")
        col_ind_var_new = "ARRAY[" + ','.join(outstr_array) + "]"
        return (col_ind_var_new, outstr_array)

    if columns_exist_in_table(tbl_source, [col_ind_var], schema_madlib):
        # if the input is a column name and not an expression
        return analyze_single_input_str(schema_madlib, tbl_source,
                                        col_ind_var, excluded)
    else:
        # if input is an expression resulting in an array output
        matched = is_string_formatted_as_array_expression(col_ind_var)
        if matched:
            # array expression starts with the word "ARRAY"
            outstr_array = _string_to_array(matched.group(1))
        else:
            # any other form of array expression
            n_feat = plpy.execute(""" SELECT array_upper({indep_var}, 1) as num_feat
                                      FROM {source} LIMIT 1
                                  """.format(indep_var=col_ind_var,
                                             source=tbl_source))[0]["num_feat"]
            if n_feat is None:
                # array_upper yields NULL for a NULL or empty array; report
                # that instead of failing on the arithmetic below
                plpy.error("Elastic Net error: Cannot determine the number of "
                           "features because the independent variable "
                           "expression is NULL or an empty array in the "
                           "sampled row!")
            outstr_array = ["[" + str(i) + "]" for i in range(1, n_feat + 1)]
        # We allow expressions for independent variables that could start with
        # something other than 'array'
        # Example use case: input independent variable of array column
        # adding an intercept could be done as '1 || x' where 'x' is array
        # of independent variables.
        return (col_ind_var, outstr_array)
# ------------------------------------------------------------------------


def analyze_single_input_str(schema_madlib, tbl_source, col_ind_var,
                             excluded=None):
    """
    Args:
        @param schema_madlib: string, Name of schema where MADlib is installed
        @param tbl_source: string, Name of input table
        @param col_ind_var: string, Name of independent variable
                            (must be single column name and of ARRAY type )
        @param excluded: string, Indices of array elements to exclude
                         (parsed with _string_to_array), or None to keep
                         every element.

    Returns:
        Tuple (independent variable expression, list of feature names).
        The feature names have the form "<column>[<index>]". The expression
        is the column itself when nothing is excluded, otherwise an ARRAY[...]
        of the remaining elements.
    """
    if columns_exist_in_table(tbl_source, [col_ind_var], schema_madlib):
        # a single column is independent variable
        # which means that it is an array
        # excluded must be a string containing integers
        if not is_col_array(tbl_source, col_ind_var):
            plpy.error("Elastic Net error: The independent variable must be an array!")

        # The array length is sampled from a single row of the table
        dimension = plpy.execute(
            """
            SELECT array_upper({col_ind_var}, 1) AS dimension
            FROM {tbl_source} limit 1
            """.format(tbl_source=tbl_source,
                       col_ind_var=col_ind_var))[0]["dimension"]
        if dimension is None:
            # array_upper yields NULL for a NULL or empty array; report that
            # instead of failing on the comparisons and range() below
            plpy.error("Elastic Net error: Cannot determine the number of "
                       "features because the independent variable is NULL "
                       "or an empty array in the sampled row!")

        if excluded is not None:
            s = _string_to_array(excluded)
            invalid_excluded = """
                               Elastic Net error: When the independent variable is
                               an array column, excluded values can only be indices
                               (i.e. between 1 and {0})""".format(dimension)
            try:
                s = [int(i) for i in s]
            except (TypeError, ValueError):
                # only conversion failures mean "not an index"; anything
                # else should propagate unchanged
                plpy.error(invalid_excluded)
            if any(i < 1 or i > dimension for i in s):
                plpy.error(invalid_excluded)
        else:
            s = []

        outstr_array = ["%s[%s]" % (col_ind_var, str(i))
                        for i in range(1, dimension + 1) if i not in s]
        if s:
            # some elements were dropped, so rebuild the array explicitly
            col_ind_var_new = "ARRAY[" + ",".join(outstr_array) + "]"
        else:
            col_ind_var_new = col_ind_var

        return (col_ind_var_new, outstr_array)
    else:
        plpy.error("Elastic Net error: Single column name included for "
                   "independent variable is not found in source table.")
# ------------------------------------------------------------------------


def elastic_net_predict_all(schema_madlib, tbl_model, tbl_new_source,
                            col_id, tbl_predict, **kwargs):
    """
    Apply a trained model to every row of a table and store the predictions
    in a new table. Useful for general CV

    @param schema_madlib   Name of the schema where MADlib is installed
    @param tbl_model       Model table; its "_summary" companion table is
                           read for the grouping columns
    @param tbl_new_source  Table with the rows to predict on
    @param col_id          ID column (or expression) of tbl_new_source
    @param tbl_predict     Name of the table to create

    @return None
    """
    summary_sql = "SELECT grouping_col FROM {summary_table}".format(
        summary_table=add_postfix(tbl_model, "_summary"))
    grouping_col = plpy.execute(summary_sql)[0]["grouping_col"]

    with MinWarning("error"):
        family_sql = "SELECT family FROM {tbl_model} ".format(tbl_model=tbl_model)
        family = plpy.execute(family_sql)[0]["family"].lower()

        if family in ("gaussian", "linear"):
            predict_func = "elastic_net_gaussian_predict"
        elif family in ("binomial", "logistic"):
            predict_func = "elastic_net_binomial_predict"
        else:
            plpy.error("Elastic Net error: Not a supported response family!")

        if col_id is None or col_id == '':
            plpy.error("Elastic Net error: invalid ID column provided!")
        # Reuse the column's own name in the output when col_id is a plain
        # column of the source table; otherwise fall back to a fixed alias
        id_is_column = columns_exist_in_table(tbl_new_source, [col_id],
                                              schema_madlib)
        elastic_net_predict_id = (col_id if id_is_column
                                  else 'elastic_net_predict_id')

        features_sql = """ SELECT features AS fs FROM {tbl_model}
                                  """.format(tbl_model=tbl_model)
        features = plpy.execute(features_sql)[0]["fs"]
        dense_vars_str = "ARRAY[" + ", ".join(features) + "]"
        # Must be careful to avoid possible name conflicts

        if grouping_col and grouping_col != 'NULL':
            # one model row per group: match source rows to their group
            template = """
                CREATE TABLE {tbl_predict} AS
                    SELECT
                        {elastic_net_predict_id},
                        {schema_madlib}.{predict_func}(coef_all, intercept, ind_var)
                             AS prediction
                    FROM
                        {tbl_model} as tbl1
                        JOIN
                        (SELECT
                            {grouping_col},
                            {col_id} as {elastic_net_predict_id},
                            {dense_vars_str} as ind_var
                        FROM
                            {tbl_new_source}) tbl2
                        USING ({grouping_col})
                        ORDER BY {grouping_col}, {elastic_net_predict_id}
                """
        else:
            template = """
            CREATE TABLE {tbl_predict} AS
                SELECT
                    {elastic_net_predict_id},
                    {schema_madlib}.{predict_func}(coef_all, intercept, ind_var)
                         AS prediction
                FROM
                    {tbl_model} as tbl1,
                    (
                        SELECT
                            {col_id} as {elastic_net_predict_id},
                            {dense_vars_str} as ind_var
                        FROM
                            {tbl_new_source}
                    ) tbl2
            """
        plpy.execute(template.format(
            tbl_predict=tbl_predict,
            elastic_net_predict_id=elastic_net_predict_id,
            schema_madlib=schema_madlib,
            predict_func=predict_func,
            tbl_model=tbl_model,
            grouping_col=grouping_col,
            col_id=col_id,
            dense_vars_str=dense_vars_str,
            tbl_new_source=tbl_new_source))
    return None
# ------------------------------------------------------------------------
