# coding=utf-8
"""
@file logistic.py_in
@brief Logistic Regression: Driver functions
@namespace logistic
@brief Logistic Regression: Driver functions
"""
import plpy
from utilities.group_control import GroupIterationController
from utilities.utilities import __unique_string
from utilities.utilities import _string_to_array
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import scalar_col_has_no_null
# ========================================================================
def __compute_logregr(schema_madlib, rel_args, rel_state, rel_source,
dep_col, ind_col, optimizer, grouping_col,
grouping_str, **kwargs):
"""
Compute logistic regression coefficients
This method serves as an interface to different optimization algorithms.
By default, iteratively reweighted least squares is used, but for data with
a lot of columns the conjugate-gradient method might perform better.
@param schema_madlib Name of the MADlib schema, properly escaped/quoted
@param rel_args Stores parameters that are needed by optimizers:
max_iter and tolerance
@param rel_state Store the iteration states
@param rel_source Name of relation containing the training data
@param dep_col Name of dependent column in training data (of type BOOLEAN)
@param ind_col Name of independent column in training data (of type
DOUBLE PRECISION[])
@param optimizer Name of the optimizer. 'newton' or 'irls': Iteratively
reweighted least squares, 'cg': conjugate gradient or
'igd': incremental gradient descent
@param kwargs We allow the caller to specify additional arguments (all of
which will be ignored though). The purpose of this is to allow the
caller to unpack a dictionary whose element set is a superset of
the required arguments by this function.
@return Number of iterations that has been run
"""
    iterationCtrl = GroupIterationController(
        rel_args=rel_args,
        rel_state=rel_state,
        stateType="double precision[]",
        schema_madlib=schema_madlib,  # Identifiers start here
        rel_source=rel_source,
        ind_col=ind_col,
        dep_col=dep_col,
        optimizer=optimizer,
        grouping_col=grouping_col,
        grouping_str=grouping_str)
with iterationCtrl as it:
it.iteration = 0
while True:
it.update(
"""
{schema_madlib}.__logregr_{optimizer}_step(
({dep_col})::boolean,
({ind_col})::double precision[],
{rel_state}._state)
""")
if it.test(
"""
{iteration} >= _args.max_iter
or
{schema_madlib}.__logregr_{optimizer}_step_distance(
_state_previous, _state_current) < _args.tolerance
"""):
break
return iterationCtrl.iteration
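
# Illustration (a sketch, not the controller's exact output): each it.update()
# call above is expanded by GroupIterationController into a statement that
# advances every group's state by one optimizer step, roughly of the form
#
#   insert into <rel_state>
#   select <grouping columns>, <iteration + 1>,
#          <schema_madlib>.__logregr_irls_step(
#              (<dep_col>)::boolean,
#              (<ind_col>)::double precision[],
#              <state from previous iteration>)
#   from <rel_source>
#   group by <grouping columns>
#
# The it.test() call then applies __logregr_irls_step_distance() to
# _state_previous and _state_current, stopping once the distance falls below
# the tolerance stored in <rel_args> or max_iter is reached.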
# ========================================================================
def logregr_train(schema_madlib, tbl_source, tbl_output, dep_col, ind_col,
grouping_col, max_iter, optimizer, tolerance, **kwargs):
"""
Train logistic model
@param schema_madlib Name of the MADlib schema, properly escaped/quoted
@param tbl_source Name of relation containing the training data
@param tbl_output Name of relation where model will be outputted
@param dep_col Name of dependent column in training data (of type BOOLEAN)
@param ind_col Name of independent column in training data (of type
DOUBLE PRECISION[])
@param grouping_col List of column names on which to group the data
@param max_iter The maximum number of iterations that are allowed.
@param optimizer Name of the optimizer. 'newton' or 'irls': Iteratively
reweighted least squares, 'cg': conjugate gradient or 'igd':
incremental gradient descent
@param tolerance The precision that the results should have
@param kwargs We allow the caller to specify additional arguments (all of
which will be ignored though). The purpose of this is to allow the
caller to unpack a dictionary whose element set is a superset of
the required arguments by this function.
@return A composite value which is __logregr_result defined in logistic.sql_in
"""
optimizer = __logregr_validate_args(schema_madlib,
tbl_source, tbl_output,
dep_col, ind_col,
grouping_col, max_iter,
optimizer, tolerance)
return __logregr_train_compute(schema_madlib, tbl_source, tbl_output,
dep_col, ind_col, grouping_col, max_iter,
optimizer, tolerance, **kwargs)
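
# Usage sketch (illustrative only; the table and column names below are
# hypothetical): from SQL, this driver is reached through the
# madlib.logregr_train() UDF, e.g.
#
#   select madlib.logregr_train(
#       'patients',                            -- tbl_source
#       'patients_logregr',                    -- tbl_output
#       'second_attack',                       -- dep_col (boolean)
#       'array[1, treatment, trait_anxiety]',  -- ind_col (double precision[])
#       NULL,                                  -- grouping_col
#       20,                                    -- max_iter
#       'irls',                                -- optimizer
#       0.0001);                               -- tolerance
#
# The fitted model is then read back from the output table:
#
#   select coef, log_likelihood, p_values from patients_logregr;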
# ========================================================================
def __logregr_validate_args(schema_madlib, tbl_source, tbl_output, dep_col,
                            ind_col, grouping_col, max_iter, optimizer,
                            tolerance):
    """
    Validate the arguments
    """
    if not tbl_source or tbl_source.lower() in ('null', '') or \
            not table_exists(tbl_source):
        plpy.error("Logregr error: Data table does not exist!")

    if table_is_empty(tbl_source):
        plpy.error("Logregr error: Data table is empty!")

    if not tbl_output or tbl_output.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid output table name!")

    if table_exists(tbl_output):
        plpy.error("Logregr error: Output table already exists. "
                   "Drop the table before calling the function.")

    if not dep_col or dep_col.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid dependent column name!")

    if not scalar_col_has_no_null(tbl_source, dep_col):
        plpy.error("Logregr error: Dependent variable has NULL values! "
                   "Please filter out NULL values before using this function!")

    if not ind_col or ind_col.lower() in ('null', ''):
        plpy.error("Logregr error: Invalid independent column name!")

    # grouping_col is optional, but if provided it must name existing columns
    if grouping_col:
        if grouping_col.lower() in ('null', ''):
            plpy.error("Logregr error: Invalid grouping column name!")
        if not columns_exist_in_table(tbl_source,
                                      _string_to_array(grouping_col),
                                      schema_madlib):
            plpy.error("Logregr error: Grouping column does not exist!")

    if max_iter <= 0:
        plpy.error("Logregr error: Maximum number of iterations must be positive!")

    if tolerance < 0:
        plpy.error("Logregr error: The tolerance cannot be negative!")

    if optimizer == "newton":
        optimizer = "irls"
    elif optimizer not in ("irls", "cg", "igd"):
        plpy.error("Logregr error: Unknown optimizer requested. "
                   "Must be 'newton'/'irls', 'cg', or 'igd'.")

    return optimizer
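
# For example (illustrative), 'newton' is accepted as a synonym and
# normalized to the internal optimizer name:
#
#   optimizer = __logregr_validate_args(schema_madlib, 'patients',
#                                       'patients_logregr', 'second_attack',
#                                       'features', None, 20,
#                                       'newton', 0.0001)
#   # optimizer is now 'irls'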
# ========================================================================
def __logregr_train_compute(schema_madlib, tbl_source, tbl_output, dep_col,
                            ind_col, grouping_col, max_iter, optimizer,
                            tolerance, verbose=False, **kwargs):
"""
Create an output table (drop if exists) that contains the logistic
regression model
"""
    # Save the current message level, then lower it so that per-iteration
    # notices are suppressed unless verbose output was requested
    old_msg_level = plpy.execute(
        "select setting from pg_settings where name='client_min_messages'"
        )[0]['setting']
    if verbose:
        plpy.execute("set client_min_messages to warning")
    else:
        plpy.execute("set client_min_messages to error")
    args = dict(schema_madlib=schema_madlib,
                tbl_source=tbl_source,
                tbl_output=tbl_output,
                dep_col=dep_col,
                ind_col=ind_col,
                max_iter=max_iter,
                optimizer=optimizer,
                tolerance=tolerance,
                tbl_logregr_args=__unique_string(),
                tbl_logregr_state=__unique_string(),
                # Map each optimizer to its result-composing function
                # (defined in logistic.sql_in); 'newton' is an alias of 'irls'
                irls="__logregr_irls_result",
                newton="__logregr_irls_result",
                cg="__logregr_cg_result",
                igd="__logregr_igd_result")
plpy.execute("select {schema_madlib}.create_schema_pg_temp()".format(**args))
plpy.execute(
"""
drop table if exists pg_temp.{tbl_logregr_args};
create table pg_temp.{tbl_logregr_args} as
select
{max_iter} as max_iter,
{tolerance} as tolerance
""".format(**args))
    # Build a comma-separated list of the grouping columns, each cast to
    # text; __compute_logregr uses this to key the per-group iteration states
    if grouping_col:
        grouping_str = ','.join(col + "::text"
                                for col in _string_to_array(grouping_col))
    else:
        grouping_str = "Null"
    iteration_run = __compute_logregr(schema_madlib, args["tbl_logregr_args"],
                                      args["tbl_logregr_state"], tbl_source,
                                      dep_col, ind_col, optimizer,
                                      grouping_col=grouping_col,
                                      grouping_str=grouping_str)
grouping_str1 = "" if grouping_col is None else grouping_col + ","
grouping_str2 = "1 = 1" if grouping_col is None else grouping_col
plpy.execute(
"""
drop table if exists {tbl_output};
create table {tbl_output} as
select
{grouping_str1}
(case when (result).status = 2 then NULL::double precision[]
else (result).coef end) as coef,
(case when (result).status = 2 then NULL::double precision
else (result).log_likelihood end) as log_likelihood,
(case when (result).status = 2 then NULL::double precision[]
else (result).std_err end) as std_err,
(case when (result).status = 2 then NULL::double precision[]
else (result).z_stats end) as z_stats,
(case when (result).status = 2 then NULL::double precision[]
else (result).p_values end) as p_values,
(case when (result).status = 2 then NULL::double precision[]
else (result).odds_ratios end) as odds_ratios,
(case when (result).status = 2 then NULL::double precision
else (result).condition_no end) as condition_no,
_iteration as num_iterations
from
(select
{grouping_str1}
{schema_madlib}.{fnName}(_state) as result,
_iteration
from
{tbl_logregr_state}) t
join
(
select
{grouping_str1}
max(_iteration) as _iteration
from {tbl_logregr_state}
group by {grouping_str2}
) s
using ({grouping_str1} _iteration)
""".format(grouping_str1 = grouping_str1,
grouping_str2 = grouping_str2,
fnName = args[args["optimizer"]],
iteration_run = iteration_run,
**args))
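    # The query above joins each group's per-iteration results (t) against
    # the per-group maximum iteration number (s), so only the final state of
    # every group lands in the output table. A status of 2 marks groups whose
    # computation failed; their statistics are emitted as NULLs.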
failed_groups = plpy.execute(
"""
select count(*) as count
from {tbl_output}
where coef is Null
""".format(**args))[0]["count"]
all_groups = plpy.execute(
"""
select count(*) as count
from {tbl_output}
""".format(**args))[0]["count"]
    if grouping_col:
        plpy.info(str(all_groups - failed_groups) +
                  " groups successfully passed, and " +
                  str(failed_groups) + " groups failed")
plpy.execute("""
drop table if exists pg_temp.{tbl_logregr_args};
drop table if exists pg_temp.{tbl_logregr_state}
""".format(**args))
plpy.execute("set client_min_messages to " + old_msg_level)
return None