import plpy
from elastic_net_utils import _process_results
from elastic_net_utils import _compute_log_likelihood
from utilities.validate_args import get_cols_and_types


def _elastic_net_generate_result(optimizer, iteration_run, **args):
"""
Generate result table for all optimizers
"""
standardize_flag = "True" if args["normalization"] else "False"
source_table = args["rel_source"]
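    # SQL expression that converts the optimizer's final state into a
    # result composite with .coefficients and .intercept fields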
if optimizer == "fista":
result_func = "__gaussian_fista_result({0})".format(args["col_grp_state"])
elif optimizer == "igd":
result_func = """__gaussian_igd_result({col_grp_state},
'{sq_str}'::double precision[],
{threshold}::double precision,
{tolerance}::double precision
)
""".format(col_grp_state=args["col_grp_state"],
tolerance=args["warmup_tolerance"],
threshold=args["threshold"],
sq_str=args["sq_str"])
    else:
        plpy.error("Elastic Net error: unsupported optimizer '{0}'".format(optimizer))
    tbl_state = args["rel_state"]
grouping_column = args['grouping_col']
if grouping_column:
col_grp_key = args['col_grp_key']
grouping_str = args['grouping_str']
cols_types = dict(get_cols_and_types(args["tbl_source"]))
grouping_str1 = grouping_column + ","
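        # Comma-separated "<name> <type>" entries for the grouping columns,
        # used verbatim as column definitions in the CREATE TABLE below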
select_grouping_info = ','.join([grp_col.strip() + "\t" + cols_types[grp_col.strip()]
for grp_col in grouping_column.split(',')]) + ","
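        # Join the per-group results (keyed by col_grp_key) back to the
        # distinct grouping values, so each model row carries its original
        # grouping columns.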
out_table_qstr = """
SELECT
{grouping_str1}
(result).coefficients AS coef,
(result).intercept AS intercept
FROM
(
SELECT {schema_madlib}.{result_func} AS result, {col_grp_key}
FROM {tbl_state}
WHERE {col_grp_iteration} = {iteration_run}
) t
JOIN
(
SELECT
{grouping_str1}
array_to_string(ARRAY[{grouping_str}], ',') AS {col_grp_key}
FROM {source_table}
GROUP BY {grouping_col}, {col_grp_key}
) n_tuples_including_nulls_subq
USING ({col_grp_key})
""".format(result_func=result_func,
tbl_state=tbl_state,
grouping_col=grouping_column,
col_grp_iteration=args["col_grp_iteration"],
iteration_run=iteration_run,
grouping_str1=grouping_str1,
grouping_str=grouping_str,
col_grp_key=col_grp_key,
source_table=source_table,
schema_madlib=args["schema_madlib"])
else:
# It's a much simpler query when there is no grouping.
grouping_str1 = ""
select_grouping_info = ""
out_table_qstr = """
SELECT
(result).coefficients AS coef,
(result).intercept AS intercept
FROM (
SELECT {schema_madlib}.{result_func} AS result
FROM {tbl_state}
WHERE {col_grp_iteration} = {iteration_run}
) t
""".format(result_func=result_func, tbl_state=tbl_state,
col_grp_iteration=args["col_grp_iteration"],
iteration_run=iteration_run,
schema_madlib=args["schema_madlib"])
# Create the output table
plpy.execute("DROP TABLE IF EXISTS {tbl_result}".format(**args))
plpy.execute("""
CREATE TABLE {tbl_result} (
{select_grouping_info}
family text,
features text[],
features_selected text[],
coef_nonzero double precision[],
coef_all double precision[],
intercept double precision,
log_likelihood double precision,
standardize boolean,
iteration_run integer)
""".format(select_grouping_info=select_grouping_info, **args))
result = plpy.execute(out_table_qstr)
for res in result:
build_output_table(res, grouping_column, grouping_str1,
standardize_flag, iteration_run, **args)
    # Create the summary table, recording the grouping columns used
    # (NULL when there was no grouping).
    grouping_text = "NULL" if not grouping_column else grouping_column
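    # Count the groups whose fit produced no model (coef_all IS NULL).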
failed_groups = plpy.execute("""
SELECT count(*) AS num_failed_groups FROM {0} WHERE coef_all IS NULL
""".format(args['tbl_result']))[0]
    all_groups = plpy.execute("SELECT count(*) AS num_all_groups FROM {0}".
                              format(args['tbl_result']))[0]
args.update(failed_groups)
args.update(all_groups)
plpy.execute("""
CREATE TABLE {tbl_summary} AS
SELECT
'elastic_net'::varchar AS method,
'{tbl_source}'::varchar AS source_table,
'{tbl_result}'::varchar AS out_table,
$madlib_super_quote${col_dep_var}$madlib_super_quote$::varchar
AS dependent_varname,
$madlib_super_quote${col_ind_var}$madlib_super_quote$::varchar
AS independent_varname,
'{family}'::varchar AS family,
{alpha}::float AS alpha,
{lambda_value}::float AS lambda_value,
'{grouping_text}'::varchar AS grouping_col,
{num_all_groups}::integer AS num_all_groups,
{num_failed_groups}::integer AS num_failed_groups
""".format(grouping_text=grouping_text,
**args))
return None


def build_output_table(res, grouping_column, grouping_str1,
                       standardize_flag, iteration_run, **args):
"""
Insert model captured in "res" into the output table
"""
r_coef = res["coef"]
if r_coef:
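        # Undo the standardization applied during fitting, if any, so the
        # reported coefficients refer to the original data scale.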
if args["normalization"]:
(coef, intercept) = _restore_scale(r_coef, res["intercept"], args)
else:
coef = r_coef
intercept = res["intercept"]
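        # Separate the full (dense) coefficient vector from the nonzero
        # (sparse) subset and the corresponding selected feature names.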
(features, features_selected, dense_coef, sparse_coef) = _process_results(
coef, intercept, args["outstr_array"])
log_likelihood = _compute_log_likelihood(r_coef, res["intercept"], **args)
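        # Rebuild this row's grouping values (with a trailing comma) so they
        # can prefix the VALUES list of the INSERT below.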
if grouping_column:
grouping_info = ','.join([str(res[grp_col.strip()])
for grp_col in grouping_str1.split(',')
if grp_col.strip() in res.keys()]) + ","
else:
grouping_info = ""
fquery = """
INSERT INTO {tbl_result} VALUES
({grouping_info} '{family}', '{features}'::text[], '{features_selected}'::text[],
'{dense_coef}'::double precision[], '{sparse_coef}'::double precision[],
{intercept}, {log_likelihood}, {standardize_flag}, {iteration})
""".format(features=features, features_selected=features_selected,
dense_coef=dense_coef, sparse_coef=sparse_coef,
intercept=intercept, log_likelihood=log_likelihood,
grouping_info=grouping_info,
standardize_flag=standardize_flag, iteration=iteration_run,
**args)
plpy.execute(fquery)
# ------------------------------------------------------------------------


def _restore_scale(coef, intercept, args):
    """
    Restore the coefficients and intercept to the original scale of the data.
    """
    rcoef = [0] * len(coef)
    if args["family"] == "gaussian":
        # For linear regression the intercept starts at the mean of the
        # dependent variable.
        rintercept = float(args["y_scale"]["mean"])
    elif args["family"] == "binomial":
        rintercept = float(intercept)
    else:
        plpy.error("Elastic Net error: unsupported family '{0}'".format(args["family"]))
    for i in range(len(coef)):
        if args["x_scales"]["std"][i] != 0:
            # Undo x-standardization: beta_orig = beta_std / std(x), and the
            # intercept absorbs the mean shift of each predictor.
            rcoef[i] = coef[i] / args["x_scales"]["std"][i]
            rintercept -= (coef[i] * args["x_scales"]["mean"][i] /
                           args["x_scales"]["std"][i])
    return (rcoef, rintercept)