# blob: 86f0c175dcf3c8ea9e7fcd018a69e18be3bd5b02 [file] [log] [blame]
"""
@file in_mem_group_control.py_in
@brief in-memory grouping controller classes
"""
import plpy
import math
from control import MinWarning
from utilities import unique_string
from collections import namedtuple
from collections import Iterable
class BaseState(object):
    """@brief Abstraction for intermediate iteration state

    Maps a group key to that group's transition state. A state of None means
    the group has no state yet. Instances support len(), item get and item
    set on the underlying key -> state mapping.
    """

    def __init__(self, **kwargs):
        """Create an empty container, then forward kwargs to initialize()."""
        self._state = {}
        # Cache for is_none(); None means "not computed yet".
        self._is_none = None
        self.initialize(**kwargs)

    def __len__(self):
        return len(self._state)

    def __getitem__(self, k):
        return self._state[k]

    def __setitem__(self, k, v):
        self._state[k] = v

    @property
    def keys(self):
        """List of group keys, in the same order as the stored states."""
        return list(self._state.keys())

    @property
    def values(self):
        """All group states concatenated into one flat list.

        Returns an empty list when every stored state is None.
        """
        if self.is_none():
            return []
        return [s for x in self._state.values() for s in x]

    def delete(self, keys_to_remove):
        """Drop the given group keys; keys that are not present are ignored."""
        for k in keys_to_remove:
            self._state.pop(k, None)
        # membership changed, so the cached answer may be stale
        self._is_none = None

    def initialize(self,
                   col_grp_key='',
                   col_grp_state='',
                   ret_states=None, **kwargs):
        """Load initial states from result rows; see update() for arguments."""
        self.update(col_grp_key, col_grp_state, ret_states)

    def isclose_to_zero(self, a, b=0.0, rel_tol=1e-09, abs_tol=0.0):
        '''
        A variation of Python 3.5 math.isclose()
        https://hg.python.org/cpython/file/tip/Modules/mathmodule.c#l1993

        Returns True when a and b are equal within the given tolerances.
        With the defaults (b=0.0, abs_tol=0.0) the relative test can never
        succeed against zero, so only a value exactly equal to 0 is reported
        as close; pass a positive abs_tol to accept near-zero values.

        Raises ValueError if either tolerance is negative.
        '''
        # sanity check on the inputs
        if rel_tol < 0 or abs_tol < 0:
            raise ValueError("tolerances must be non-negative")

        # short circuit exact equality -- needed to catch two infinities of
        # the same sign. And perhaps speeds things up a bit sometimes.
        if a == b:
            return True

        # This catches the case of two infinities of opposite sign, or
        # one infinity and one finite number. Two infinities of opposite
        # sign would otherwise have an infinite relative tolerance.
        # Two infinities of the same sign are caught by the equality check
        # above.
        if math.isinf(a) or math.isinf(b):
            return False

        # now do the regular computation
        # this is essentially the "weak" test from the Boost library
        diff = math.fabs(b - a)
        return ((diff <= math.fabs(rel_tol * b)) or
                (diff <= math.fabs(rel_tol * a)) or
                (diff <= abs_tol))

    def are_last_state_value_zero(self):
        """Return True only if every group's state ends in a zero value.

        The last element of each group's state array is checked with
        isclose_to_zero(). A single non-zero value makes the result False.
        A group whose state is None or empty has no last element and is
        counted as non-zero. With no groups at all the result is True.
        """
        return all(
            val is not None and len(val) > 0 and self.isclose_to_zero(val[-1])
            for val in self._state.values())

    def update(self, col_grp_key, col_grp_state, ret_states):
        """Store the states found in result rows.

        @param col_grp_key Name of the row field holding the group key
        @param col_grp_state Name of the row field holding the state; the
            empty string means "register every key with a None state"
        @param ret_states Sequence of mapping-like rows
        @return List of group keys whose state in ret_states was None. The
            stored state of such a group is left untouched.
        """
        failed_grp_keys = []
        if not ret_states:
            return failed_grp_keys

        t0 = ret_states[0]
        # no key column in table ret_states
        if col_grp_key not in t0:
            return failed_grp_keys

        # initialize state to None
        if col_grp_state == '':
            self._is_none = True
            for s in ret_states:
                self._state[s[col_grp_key]] = None
            return failed_grp_keys

        for t in ret_states:
            _grp_key, _grp_state = t[col_grp_key], t[col_grp_state]
            if _grp_state is None:
                failed_grp_keys.append(_grp_key)
            else:
                self._state[_grp_key] = _grp_state

        # no need to update if all failed
        if len(failed_grp_keys) < len(self):
            self._is_none = False
        return failed_grp_keys

    def update_from_state(self, other, keys=None):
        """Copy states from another BaseState.

        Entries in self that are not in other are kept. When keys is given,
        only those keys are copied (each must exist in other). Anything that
        is not a BaseState is silently ignored.
        """
        if not isinstance(other, BaseState):
            return
        if keys is None:
            self._state.update(other._state)
        else:
            for k in keys:
                self[k] = other[k]
        # reset cache
        self._is_none = None

    def sync_from(self, other):
        """Discard the current content and mirror the states of other."""
        self._state = {}
        self.update_from_state(other)
        self._is_none = other.is_none()

    def is_none(self):
        """Return True if every stored state is None (or nothing is stored).

        The answer is cached until the container is modified through
        update(), delete(), update_from_state() or sync_from().
        """
        if self._is_none is None:
            self._is_none = all(v is None for v in self._state.values())
        return self._is_none

    def interpret(self, schema_madlib, state_type, keys=None):
        """Decode stored states into loss and gradient norm.

        Runs {schema_madlib}.internal_linear_svm_igd_result on the state of
        each requested group.

        @param schema_madlib Schema that holds the MADlib functions
        @param state_type SQL type name used for the prepared parameter
        @param keys A single key, an iterable of keys, or None for all keys
        @return dict mapping each key to a result row with the fields
            'loss' and 'norm_of_gradient'
        """
        if keys is None:
            keys = self.keys
        elif isinstance(keys, str) or not isinstance(keys, Iterable):
            keys = [keys]

        s = dict.fromkeys(keys)
        plan = plpy.prepare(
            """
SELECT
(result).loss AS loss,
(result).norm_of_gradient AS norm_of_gradient
FROM (
SELECT
{schema_madlib}.internal_linear_svm_igd_result($1)
AS result
) subq
""".format(schema_madlib=schema_madlib), [state_type])
        for k in keys:
            s[k] = plpy.execute(plan, [self._state[k]])[0]
        return s
class Bytea8State(BaseState):
    """@brief bytea8 type state

    Construction is inherited unchanged from BaseState; only the shape of
    `values` differs.
    """

    @property
    def values(self):
        """Stored per-group states, one entry per group, not flattened."""
        return self._state.values()
def state_factory(is_bytea8, **kwargs):
    """Build the state container matching the transition-state type.

    @param is_bytea8 Truthy to get a Bytea8State, falsy to get a BaseState
    @param kwargs Forwarded untouched to the chosen class constructor
    @return The newly created state object
    """
    state_cls = Bytea8State if is_bytea8 else BaseState
    return state_cls(**kwargs)
class GroupIterationController(object):
    """
    @brief Abstraction for implementing in-memory iteration controller for
    SQL aggregate with STYPE=madlib.bytea8 (e.g. DynamicStruct) in PL/Python

    Per-group transition states are kept in Python memory (new_states,
    old_states, finished_states) and shipped to SQL as array parameters of
    prepared statements. The object is a context manager: __enter__ creates
    the state table, update()/test() drive the iterations and final() writes
    the finished states back.
    """

    def __init__(self, arg_dict):
        """
        arg_dict: Dictionary containing arguments to be defined by the calling function:
            Necessary:
                schema_madlib: Schema that holds the MADlib functions
                grouping_col: Comma-separated grouping columns, or None
                grouping_str: Expression list used to build the group key
                rel_source: Name of the source relation
                rel_state: Name of the (temporary) state table
                col_grp_state: Name of the state column
                col_grp_iteration: Name of the iteration column
                col_dep_var: Name of the dependent column name
                col_ind_var: Name of the independent column name
            Optional:
                state_type: Type of the transition state
                    (can be double precision[] or <schema_madlib>.bytea8);
                    it is formatted with arg_dict, so it may contain
                    placeholders such as {schema_madlib}.
                    Defaults to double precision[]
                col_n_tuples: Name of the column containing count of tuples.
                    Set to a unique string if not defined
                col_grp_key: Name of the group key column.
                    Set to a unique string if not defined
                as_rel_source: Alias for the source relation (default '_src')
                verbose: Emit per-iteration notices (default False); info()
                    then also expects a 'stepsize' entry
        """
        self.schema_madlib = arg_dict['schema_madlib']
        self.in_with = False
        self.iteration = -1
        self.is_group_null = arg_dict["grouping_col"] is None
        self.verbose = arg_dict.get('verbose', False)

        self.kwargs = dict(arg_dict)
        self.kwargs.update(
            as_rel_source=arg_dict.get('as_rel_source', '_src'),
            state_type=arg_dict.get('state_type', 'double precision[]').format(**arg_dict),
            col_grp_null=unique_string(desp='col_grp_null'),
            col_n_tuples=self.kwargs.get('col_n_tuples',
                                         unique_string(desp='col_n_tuples')),
            col_grp_key=self.kwargs.get('col_grp_key',
                                        unique_string(desp='col_grp_key')),
            grouping_col=("NULL"
                          if arg_dict["grouping_col"] is None
                          else arg_dict["grouping_col"]),
        )
        self.grp_to_n_tuples = {}
        self.failed_grp_keys = []

        self.is_state_type_bytea8 = False
        state_type = self.kwargs['state_type']
        if state_type == "{0}.bytea8".format(self.schema_madlib):
            self.is_state_type_bytea8 = True
        elif state_type.lower() in ("double precision[]", "float8[]"):
            self.is_state_type_bytea8 = False
        else:
            plpy.error("Internal error: unexpected state type!")

        self.new_states = state_factory(self.is_state_type_bytea8)
        self.old_states = state_factory(self.is_state_type_bytea8)
        self.finished_states = state_factory(self.is_state_type_bytea8)
        self.group_param = self._init_group_param()

    def _init_group_param(self):
        """Precompute the SQL fragments that depend on grouping and state type.

        @return GroupParam namedtuple with the fields groupby_str, using_str,
            select_rel_state, grouped_state_type, grp_key, select_n_tuples
        """
        _grp_key = ("array_to_string(ARRAY[{grouping_str}], ',')"
                    .format(grouping_str=self.kwargs['grouping_str']))
        _select_rel_state = ("SELECT "
                             "grp_key AS {col_grp_key},"
                             "state AS {col_grp_state} "
                             "FROM {schema_madlib}._gen_state($1, NULL, $2)"
                             .format(**self.kwargs))
        _select_n_tuples = ("SELECT "
                            "unnest($3) AS {col_grp_key}, "
                            "unnest($4) AS {col_n_tuples}"
                            .format(**self.kwargs))
        _using_str = "ON TRUE"
        _grouped_state_type = "float8[]"
        _groupby_str = ""

        if not self.is_group_null:
            _groupby_str = "GROUP BY {grouping_col}, {col_grp_key}".format(
                **self.kwargs)
            _using_str = "USING ({col_grp_key})".format(**self.kwargs)
            _grp_key = self.kwargs['col_grp_key']

        if self.is_state_type_bytea8:
            _select_rel_state = ("SELECT "
                                 "unnest($1) AS {col_grp_key}, "
                                 "unnest($2) AS {col_grp_state}"
                                 "".format(**self.kwargs))
            _grouped_state_type = "{schema_madlib}.bytea8[]".format(**self.kwargs)

        GroupParam = namedtuple('GroupParam',
                                'groupby_str, using_str,'
                                'select_rel_state,grouped_state_type,'
                                'grp_key, select_n_tuples')
        return GroupParam(groupby_str=_groupby_str,
                          using_str=_using_str,
                          select_rel_state=_select_rel_state,
                          grp_key=_grp_key,
                          select_n_tuples=_select_n_tuples,
                          grouped_state_type=_grouped_state_type)

    def _state_select_list(self, key_param, state_alias):
        """Return the SQL text that exposes group key and state as columns.

        @param key_param 1-based index of the prepared-statement parameter
            holding the group keys; the states are read from the parameter
            right after it
        @param state_alias Column alias given to the state
        @return Text to place after SELECT. For non-bytea8 states it also
            carries its own FROM clause on {schema_madlib}._gen_state.
        """
        if self.is_state_type_bytea8:
            template = ("unnest(${key_idx}) AS {col_grp_key}, "
                        "unnest(${state_idx}) AS {state_alias}")
        else:
            template = ("grp_key AS {col_grp_key}, state AS {state_alias} "
                        "FROM {schema_madlib}._gen_state"
                        "(${key_idx}, NULL, ${state_idx})")
        # Only the two entries the templates need are taken from kwargs so
        # that caller-supplied keys cannot clash with the local placeholders.
        return template.format(key_idx=key_param,
                               state_idx=key_param + 1,
                               state_alias=state_alias,
                               col_grp_key=self.kwargs['col_grp_key'],
                               schema_madlib=self.kwargs['schema_madlib'])

    def __enter__(self):
        """Create the state table, validate grouping values, load group keys.

        Raises (through plpy.error) when a grouping column contains NULL.
        @return self
        """
        verbosity_level = "info" if self.verbose else "error"
        with MinWarning(verbosity_level):
            ############################
            # create state table
            # currently assuming that groups is passed as a valid array
            group_col = ("NULL::integer as {col_grp_null}" if self.is_group_null
                         else "{grouping_col}").format(**self.kwargs)
            groupby_str = ("{col_grp_null}" if self.is_group_null
                           else "{grouping_col}").format(**self.kwargs)
            plpy.execute("""
DROP TABLE IF EXISTS {rel_state};
CREATE TEMPORARY TABLE {rel_state} AS (
SELECT
array_to_string(ARRAY[{grouping_str}], ',') AS {col_grp_key},
0::integer AS {col_grp_iteration},
NULL::{state_type} AS {col_grp_state},
count(*) AS {col_n_tuples},
{group_col}
FROM {rel_source}
WHERE ({col_dep_var}) IS NOT NULL
AND NOT {schema_madlib}.array_contains_null({col_ind_var})
GROUP BY {groupby_str}
);
""".format(group_col=group_col,
           groupby_str=groupby_str,
           **self.kwargs))

            ############################
            # checking null in group values
            # We cannot allow NULL due to array_to_string cannot handle it well.
            if not self.is_group_null:
                null_test = " OR ".join(
                    [g.strip() + " is NULL"
                     for g in self.kwargs['grouping_col'].split(",")])
                null_count = plpy.execute("""
SELECT count(*) FROM {rel_state} WHERE {null_test}
""".format(null_test=null_test,
           rel_state=self.kwargs['rel_state']))[0]['count']
                if null_count != 0:
                    plpy.error("Grouping error: at least one of the grouping "
                               "columns contains NULL values! Please filter "
                               "out those NULL values.")

            ############################
            # initialize states
            rel_state_str = self.kwargs['rel_state']
            col_grp_key = self.kwargs['col_grp_key']
            col_n_tuples = self.kwargs['col_n_tuples']
            ret_states = plpy.execute("SELECT * FROM " + rel_state_str)
            self.new_states.initialize(col_grp_key=col_grp_key,
                                       col_grp_state='',
                                       ret_states=ret_states)
            for s in ret_states:
                self.grp_to_n_tuples[s[col_grp_key]] = int(s[col_n_tuples])

        self.in_with = True
        return self

    def are_last_state_value_zero(self):
        """Delegate to new_states.are_last_state_value_zero()."""
        return self.new_states.are_last_state_value_zero()

    def info(self):
        """ Logging intermediate state information """
        if not self.verbose:
            return

        group_param = self.group_param
        schema_madlib = self.kwargs['schema_madlib']
        res = self.new_states.interpret(schema_madlib,
                                        group_param.grouped_state_type)
        iteration = self.iteration
        # Built by concatenation so that source indentation can never leak
        # into the message text.
        output_str = ("DEBUG: "
                      "grp = {grp:10s}, "
                      "iter = {iteration:5d}, "
                      "loss = {loss:.5e}, "
                      "|gradient| = {normg:.5e}, "
                      "stepsize = {stepsize:.5e}")
        for grp, t in res.items():
            loss, normg = t['loss'], t['norm_of_gradient']
            # 'stepsize' is filled in from self.kwargs
            plpy.notice(output_str.format(
                grp=grp, iteration=iteration,
                loss=loss, normg=normg,
                **self.kwargs))

    def final(self):
        """ Store the final converged state to a table for output """
        group_param = self.group_param
        insert_sql = """
INSERT INTO {rel_state}
SELECT
{col_grp_key},
{iteration}::int,
{col_grp_state},
{col_n_tuples}::bigint,
{grouping_col}
FROM (
SELECT {grouping_col}, {col_grp_key}
FROM {rel_state}
) AS _src
JOIN ( {select_rel_state} ) AS _rel_state
USING ({col_grp_key})
JOIN ( {select_n_tuples} ) AS _rel_n_tuples
USING ({col_grp_key})
""".format(
            iteration=self.iteration,
            select_rel_state=group_param.select_rel_state,
            select_n_tuples=group_param.select_n_tuples,
            **self.kwargs)
        insert_plan = plpy.prepare(insert_sql,
                                   ["text[]", group_param.grouped_state_type,
                                    "text[]", "bigint[]"])

        # keys and counts are taken from the same unmodified dict, so the
        # two lists line up element by element
        grp_keys = list(self.grp_to_n_tuples)
        grp_counts = list(self.grp_to_n_tuples.values())
        plpy.execute(insert_plan, [self.finished_states.keys,
                                   self.finished_states.values,
                                   grp_keys,
                                   grp_counts])
        if self.failed_grp_keys:
            # failed groups are submitted with an empty state array
            plpy.execute(insert_plan,
                         [self.failed_grp_keys,
                          [],
                          grp_keys,
                          grp_counts])

    def __exit__(self, type, value, tb):
        """Mark the controller as no longer inside its with-block."""
        self.in_with = False

    def get_state_size(self):
        """
        Return the size of the state. Greenplum does not have the
        array_length() method yet, hence this work-around.

        The length is measured in Python on the first row that the query
        over the current states returns.
        """
        unnest_str_current = self._state_select_list(3, "_state_current")
        grouped_state_type = self.group_param.grouped_state_type
        eval_plan = plpy.prepare("""
SELECT _state_current
FROM (
SELECT {unnest_str_current}
) subq1
""".format(unnest_str_current=unnest_str_current, **self.kwargs),
                                 ["text[]", grouped_state_type] * 2)
        rows = plpy.execute(eval_plan,
                            [self.old_states.keys,
                             self.old_states.values,
                             self.new_states.keys,
                             self.new_states.values])
        return len(rows[0]['_state_current'])

    def get_param_value_per_group(self, select_col):
        """
        Return desired values from the current state.

        Example:
            to return the value in the last index of the state, the 'select_col'
            can be set to '_state_current[array_length(_state_current,1)] AS last'

        This is especially useful when there is grouping and we want to find
        some specific state value (such as loss) for each group.

        @param select_col SQL select-list text; it may refer to
            _state_current
        @return Query result with the selected column(s) and the group key
            column, one row per group
        """
        # Check convex/mlp_igd.py_in for an example of how this is useful.
        unnest_str_current = self._state_select_list(3, "_state_current")
        grouped_state_type = self.group_param.grouped_state_type
        eval_plan = plpy.prepare("""
SELECT
{select_col},
{col_grp_key}
FROM (
SELECT {unnest_str_current}
) subq1
""".format(select_col=select_col,
           unnest_str_current=unnest_str_current, **self.kwargs),
                                 ["text[]", grouped_state_type] * 2)
        return plpy.execute(eval_plan,
                            [self.old_states.keys,
                             self.old_states.values,
                             self.new_states.keys,
                             self.new_states.values])

    def test(self, condition):
        r"""
        Evaluate the given expression for all in-progress groups and move
        the groups for which it is TRUE from self.new_states to
        self.finished_states.

        @param condition SQL boolean expression. It is first formatted with
            {iteration} and the entries of self.kwargs. The following column
            names are defined and can be used in the condition:
            - \c _state_previous - state before the last update
            - \c _state_current - state after the last update
        @return True if no group is left in progress afterwards (also when
            there was none to begin with), otherwise False
        """
        if len(self.new_states) == 0:
            # self.new_states can become empty if the last of the groups failed
            # in the previous update
            return True

        unnest_str_previous = self._state_select_list(1, "_state_previous")
        unnest_str_current = self._state_select_list(3, "_state_current")
        grouped_state_type = self.group_param.grouped_state_type

        condition = condition.format(iteration=self.iteration, **self.kwargs)
        eval_plan = plpy.prepare("""
SELECT
CAST(({condition}) AS BOOLEAN) AS _expression,
{col_grp_key}
FROM
(
(
SELECT {unnest_str_previous}
) sub1
JOIN
(
SELECT {unnest_str_current}
) sub2
USING ({col_grp_key})
) subq1
""".format(condition=condition,
           unnest_str_current=unnest_str_current,
           unnest_str_previous=unnest_str_previous, **self.kwargs),
                                 ["text[]", grouped_state_type] * 2)
        res = plpy.execute(eval_plan,
                           [self.old_states.keys,
                            self.old_states.values,
                            self.new_states.keys,
                            self.new_states.values])

        col_grp_key = self.kwargs['col_grp_key']
        finished_keys = [t[col_grp_key] for t in res if t['_expression']]
        self.finished_states.update_from_state(self.new_states, finished_keys)
        self.new_states.delete(finished_keys)
        return len(self.new_states) == 0

    def update(self, newState, **updateKwargs):
        """
        Update the inter-iteration state

        @param newState SQL expression of (or returning) type
            <tt>state_type</tt>. It is formatted with the entries of
            self.kwargs and evaluated over the source relation (aliased
            <tt>{as_rel_source}</tt>) joined with the in-memory states
            (aliased <tt>{rel_state}</tt>, columns <tt>{col_grp_key}</tt> and
            <tt>{col_grp_state}</tt>) and the tuple counts (alias
            <tt>_rel_n_tuples</tt>, column <tt>{col_n_tuples}</tt>).
        @param updateKwargs Accepted but not used

        The iteration counter is incremented, the previous states are saved
        in self.old_states and self.new_states takes the result of evaluating
        newState. Groups for which the result is NULL are appended to
        self.failed_grp_keys.
        """
        newState = newState.format(**self.kwargs)
        self.iteration = self.iteration + 1

        group_param = self.group_param
        run_sql = """
SELECT
{_grp_key} AS {col_grp_key},
{grouping_col},
{iteration} AS {col_grp_iteration},
({newState}) AS {col_grp_state}
FROM (
SELECT *,
array_to_string(ARRAY[{grouping_str}], ',') AS {col_grp_key}
FROM {rel_source}
) AS {as_rel_source}
JOIN ( {select_rel_state} ) AS {rel_state}
{using_str}
JOIN ( {select_n_tuples} ) AS _rel_n_tuples
{using_str}
{groupby_str}
""".format(
            newState=newState,
            iteration=self.iteration,
            using_str=group_param.using_str,
            groupby_str=group_param.groupby_str,
            _grp_key=group_param.grp_key,
            select_rel_state=group_param.select_rel_state,
            select_n_tuples=group_param.select_n_tuples,
            **self.kwargs)
        # NOTE(review): the tuple counts are bound as integer[] here but as
        # bigint[] in final(). Left unchanged because caller-supplied
        # newState expressions may rely on an integer-typed count; confirm
        # before widening.
        update_plan = plpy.prepare(run_sql,
                                   ["text[]", group_param.grouped_state_type,
                                    "text[]", "integer[]"])
        res_tuples = plpy.execute(update_plan,
                                  [self.new_states.keys,
                                   self.new_states.values,
                                   list(self.grp_to_n_tuples),
                                   list(self.grp_to_n_tuples.values())])

        col_grp_state = self.kwargs['col_grp_state']
        col_grp_key = self.kwargs['col_grp_key']
        self.old_states.sync_from(self.new_states)
        self.failed_grp_keys.extend(self.new_states.update(
            col_grp_key,
            col_grp_state,
            res_tuples))