import plpy
import math
from utilities.validate_args import input_tbl_valid, output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.utilities import _assert
from utilities.utilities import py_list_to_sql_string
class Summarizer:
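    """Compute summary statistics for the columns of a database table.

    Statistics (row counts, mean/variance, quartiles, most frequent values,
    etc.) are computed in SQL via plpy and written to an output table. A
    minimal, illustrative invocation follows; the argument values shown are
    assumptions for the example, not defaults enforced by this class:

        summarizer = Summarizer('madlib', 'my_table', 'my_table_summary',
                                target_cols=None, grouping_cols=[None],
                                distinctify='Exact', get_quartiles=True,
                                xtileify='Exact', ntile_array=None,
                                how_many_mfv=10, get_mfv_quick=False,
                                n_cols_per_run=15)
        summarizer.run()
    """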
def __init__(self, schema_madlib, source_table, output_table,
target_cols, grouping_cols, distinctify, get_quartiles,
xtileify, ntile_array, how_many_mfv, get_mfv_quick,
n_cols_per_run):
self._schema_madlib = schema_madlib
self._source_table = source_table
self._output_table = output_table
self._grouping_cols = grouping_cols
self._target_cols = target_cols
self._distinctify = distinctify
self._get_quartiles = get_quartiles
self._xtileify = xtileify if xtileify else 'Exact'
self._ntile_array = ntile_array
self._how_many_mfv = how_many_mfv if how_many_mfv is not None else 10
self._get_mfv_quick = get_mfv_quick if get_mfv_quick is not None else False
self._n_cols_per_run = n_cols_per_run if n_cols_per_run is not None else 15
self._columns = None
self._column_names = None
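        # Delimiter used to pack the MFV results into a single string and
        # split them back into an array; chosen to be unlikely to occur in data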
self._delimiter = '_.*.&.!.!.&.*_'
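        # Two-sided z-score for a 95% confidence interval on the mean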
self._z_score = 1.96
def _populate_columns(self):
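        """Fetch column metadata (quoted name, type name, attribute number)
        for the source table from pg_attribute, restricted to target_cols
        when provided.
        """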
if self._target_cols:
lower_target_cols = [each_col.lower() for each_col in self._target_cols
if '"' not in each_col]
target_selection = "AND attname = ANY({0})".format(
py_list_to_sql_string(lower_target_cols, array_type="text"))
else:
target_selection = ""
self._columns = plpy.execute("""
SELECT
quote_ident(attname) AS attname,
typname,
attnum
FROM
pg_attribute a
JOIN
pg_type t
ON (a.atttypid=t.oid)
WHERE attrelid = '{tbl}'::regclass
AND attnum > 0
AND not attisdropped
{target_selection}
ORDER BY attnum
""".format(tbl=self._source_table,
target_selection=target_selection))
self._column_names = [col['attname'] for col in self._columns]
# -----------------------------------------------------------------------
# Argument validation functions
# -----------------------------------------------------------------------
def _validate_table_names(self):
"""
        Validate the source and output table names for the summary function
"""
# source table
_assert(self._source_table is not None and self._source_table.strip(),
"Summary error: Invalid data table name!")
input_tbl_valid(self._source_table, "Summary")
output_tbl_valid(self._output_table, "Summary")
def _validate_required_cols(self, required_cols):
if required_cols is not None and required_cols != [None]:
clean_cols = [c for c in required_cols if c is not None]
cols_in_tbl_valid(self._source_table, clean_cols, "Summary")
def _validate_ntile_array(self):
if self._ntile_array is not None:
_assert(len(self._ntile_array) > 0,
"Summary - Invalid parameter: ntile_array is empty.")
for ntile in self._ntile_array:
                _assert(0 < ntile < 1.0,
                        "Summary - Invalid parameter: Values in ntile_array "
                        "should be strictly between 0.0 and 1.0")
def _adjust_cols(self):
        # If the only selected column is also a grouping column, drop it from
        # grouping_cols so there is still a column to report statistics on
if (len(self._column_names) == 1 and
self._column_names[0] in self._grouping_cols):
self._grouping_cols.remove(self._column_names[0])
def _validate_params(self):
"""
Validate all parameters in the class
"""
self._validate_table_names()
self._validate_required_cols(self._target_cols)
self._validate_required_cols(self._grouping_cols)
self._validate_ntile_array()
        _assert(self._how_many_mfv is not None and self._how_many_mfv > 0,
                "Summary - Invalid parameter: Number of most frequent "
                "values requested should be positive")
self._how_many_mfv = int(self._how_many_mfv)
_assert(self._n_cols_per_run is not None and self._n_cols_per_run > 0,
"Summary - Invalid parameter: Number of columns per run"
" should be positive")
self._n_cols_per_run = int(self._n_cols_per_run)
# ----- End of argument validation functions -----------------------------
def _build_subquery(self, group_var, cols):
"""
Returns a subquery of statistics for target columns for a single
grouping variable
"""
# ensure group_var does not suffer from bad input from user
group_var = group_var.lower() if group_var and '"' not in group_var else group_var
args = {'source_table': self._source_table}
# Exclude the grouping_cols variable from the list of columns to
# report statistics on
        # Use a list (not a lazy filter object) since cols is iterated
        # multiple times below
        cols = [c for c in cols if c['attname'] != group_var]
if group_var:
args['group_value'] = "{schema_madlib}.__to_char(%s)" % group_var
args['group_var'] = "'%s'" % group_var
args['group_expr'] = "\n GROUP BY %s" % group_var
else:
args['group_value'] = "NULL"
args['group_var'] = "NULL"
args['group_expr'] = ""
args['column_names'] = ','.join(["'%s'" % c['attname'] for c in cols])
args['column_types'] = ','.join(["'%s'" % c['typname'] for c in cols])
args['column_number'] = ','.join([str(c['attnum']) for c in cols])
        if self._distinctify == 'Estimated':
args['distinct_columns'] = ','.join(["{schema_madlib}.fmsketch_dcount(%s)" % c['attname'] for c in cols])
        elif self._distinctify == 'Exact':
args['distinct_columns'] = ','.join(["count(distinct %s)" % c['attname'] for c in cols])
else:
args['distinct_columns'] = ','.join(["NULL" for c in cols])
args['missing_columns'] = ','.join(["sum(case when %s is null then 1 else 0 end)" % (
c['attname']) for c in cols])
        args['blank_columns'] = ','.join(
            ["sum(case when {0} similar to E'\\\\W*' then 1 else 0 end)".format(c['attname'])
             if c['typname'] in ('varchar', 'bpchar', 'text', 'character varying')
             else 'NULL' for c in cols])
# ------ Helper sub-functions ------
# types are obtained from pg_attribute (hence different from those
# obtained from information_schema.columns)
numeric_types = ('int2', 'int4', 'int8', 'float4', 'float8', 'numeric')
        def numeric_type(operator, c):
            if c['typname'] in numeric_types:
                return '%s(%s)' % (operator, c['attname'])
            return "NULL"
def minmax_type(minmax, c):
if c['typname'] in numeric_types:
return '%s(%s)' % (minmax, c['attname'])
if c['typname'] in ('varchar', 'bpchar', 'text'):
return "%s(length(%s))" % (minmax, c['attname'])
return "NULL"
def xtile_type(xtile, c):
            if self._xtileify == 'Exact':
if c['typname'] in numeric_types:
return "percentile_cont(%s) WITHIN GROUP (ORDER BY %s)" % (xtile, c['attname'])
            if self._xtileify == 'Estimated':
return "NULL" # NOT YET IMPLEMENTED
return "NULL"
def mfv_type(get_count, c):
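            # The MFV sketch returns a 2-D histogram: slice [0:0] holds the
            # top-k values and slice [1:1] holds their counts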
slicing = ('0:0', '1:1')[get_count]
mfv_method = ('mfvsketch_top_histogram',
'mfvsketch_quick_histogram')[self._get_mfv_quick]
return """
array_to_string(({schema_madlib}.{mfv_method}(
{column},{topk}))[0:{topk}-1][{slice}],
'{delimiter}')""".format(
schema_madlib=self._schema_madlib,
mfv_method=mfv_method,
column=c['attname'],
topk=self._how_many_mfv,
slice=slicing,
delimiter=self._delimiter)
# ------ End of Helper sub-functions ------
args['mean_columns'] = ','.join([numeric_type('avg', c) for c in cols])
args['var_columns'] = ','.join([numeric_type('variance', c) for c in cols])
args['min_columns'] = ','.join([minmax_type('min', c) for c in cols])
args['q1_columns'] = ','.join([xtile_type(0.25, c)
if self._get_quartiles
else 'NULL' for c in cols])
args['q2_columns'] = ','.join([xtile_type(0.50, c)
if self._get_quartiles
else 'NULL' for c in cols])
args['q3_columns'] = ','.join([xtile_type(0.75, c)
if self._get_quartiles
else 'NULL' for c in cols])
args['max_columns'] = ','.join([minmax_type('max', c) for c in cols])
args['ntile_columns'] = "array_to_string(array[NULL], ',')"
pos_str = "sum(case when {0} > 0 then 1 else 0 end)"
args['positive_columns'] = ','.join([pos_str.format(c['attname'])
if c['typname'] in numeric_types
else 'NULL' for c in cols])
neg_str = "sum(case when {0} < 0 then 1 else 0 end)"
args["negative_columns"] = ','.join([neg_str.format(c['attname'])
if c['typname'] in numeric_types
else 'NULL' for c in cols])
zero_str = "sum(case when {0} = 0 then 1 else 0 end)"
args["zero_columns"] = ','.join([zero_str.format(c['attname'])
if c['typname'] in numeric_types
else 'NULL' for c in cols])
if self._ntile_array:
args['ntile_columns'] = ",".join([
"array_to_string(array[" +
",".join([xtile_type(xtile, c) for xtile in self._ntile_array]) +
"], ',')" for c in cols])
args['mfv_value'] = ','.join([mfv_type(False, c) for c in cols])
args['mfv_count'] = ','.join([mfv_type(True, c) for c in cols])
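        # The template is formatted twice: the first pass fills in the args,
        # some of which still contain a {schema_madlib} placeholder that the
        # second pass resolves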
subquery = """
SELECT
{group_var}::text as group_by,
{group_value}::text as group_by_value,
array[{column_names}]::text[] as target_column,
array[{column_types}]::text[] as datatype,
array[{column_number}]::integer[] as colnum,
count(*)::bigint as rowcount,
array[{positive_columns}]::bigint[] as positive_values,
array[{negative_columns}]::bigint[] as negative_values,
array[{zero_columns}]::bigint[] as zero_values,
array[{mean_columns}]::float8[] as mean,
array[{var_columns}]::float8[] as variance,
array[{distinct_columns}]::bigint[] as distinct_values,
array[{missing_columns}]::bigint[] as missing_values,
array[{blank_columns}]::bigint[] as blank_values,
array[{min_columns}]::float8[] as min,
array[{max_columns}]::float8[] as max,
array[{q1_columns}]::float8[] as first_quartile,
array[{q2_columns}]::float8[] as median,
array[{q3_columns}]::float8[] as third_quartile,
array[{ntile_columns}]::text[] as ntiles,
array[{mfv_value}]::text[] as mfv_value,
array[{mfv_count}]::text[] as mfv_count
FROM {source_table}{group_expr}
""".format(**args).format(schema_madlib=self._schema_madlib)
return subquery
def _build_inner_query(self, group_val, cols):
subquery = self._build_subquery(group_val, cols)
query = """
SELECT
group_by,
group_by_value,
target_column,
datatype,
colnum,
rowcount,
distinct_values,
missing_values,
blank_values,
distinct_values::float8 / rowcount as fraction_distinct_values,
missing_values::float8 / rowcount as fraction_missing_values,
blank_values::float8 / rowcount as fraction_blank_values,
positive_values,
negative_values,
zero_values,
mean,
variance,
(mean - {z_score} * sqrt(variance / rowcount))::float8 as ci_lower_bound,
(mean + {z_score} * sqrt(variance / rowcount))::float8 as ci_upper_bound,
min,
max,
first_quartile,
median,
third_quartile,
ntiles,
string_to_array(mfv_value, '{delimiter}') as mfv_value,
string_to_array(mfv_count, '{delimiter}') as mfv_count
FROM
(
SELECT
group_by,
group_by_value,
unnest(target_column) AS target_column,
unnest(datatype) AS datatype,
unnest(colnum) AS colnum,
rowcount,
unnest(distinct_values) as distinct_values,
unnest(missing_values) as missing_values,
unnest(blank_values) as blank_values,
unnest(positive_values) as positive_values,
unnest(negative_values) as negative_values,
unnest(zero_values) as zero_values,
unnest(mean) as mean,
unnest(variance) as variance,
unnest(min) as min,
unnest(max) as max,
unnest(first_quartile) as first_quartile,
unnest(median) as median,
unnest(third_quartile) as third_quartile,
string_to_array(unnest(ntiles), ',') as ntiles,
unnest(mfv_value) as mfv_value,
unnest(mfv_count) as mfv_count
FROM ({subquery}) q1
) q2
""".format(schema_madlib=self._schema_madlib, subquery=subquery,
delimiter=self._delimiter, z_score=self._z_score)
return query
def _build_query(self, group_val, cols, create_table):
query = self._build_inner_query(group_val, cols)
distinct_values = ''
if self._distinctify != 'Skip':
distinct_values = """
(CASE WHEN rowcount < distinct_values THEN rowcount
ELSE distinct_values
END) AS distinct_values,"""
first_quartile = ''
median = ''
third_quartile = ''
if self._get_quartiles and self._xtileify != 'Skip':
first_quartile = 'first_quartile,'
median = 'median,'
third_quartile = 'third_quartile,'
ntiles = ''
if self._ntile_array and self._xtileify != 'Skip':
ntiles = 'ntiles::float8[] as quantile_array,'
if create_table:
final_query = 'CREATE TABLE {output_table} AS '
else:
final_query = 'INSERT INTO {output_table} '
final_query += """
SELECT
group_by as group_by,
group_by_value as group_by_value,
target_column as target_column,
colnum as column_number,
datatype as data_type,
rowcount as row_count,
{distinct_values}
missing_values as missing_values,
blank_values as blank_values,
(missing_values::float8 / rowcount) as fraction_missing,
(blank_values::float8 / rowcount) as fraction_blank,
positive_values,
negative_values,
zero_values,
mean,
variance,
CASE WHEN ci_lower_bound is NULL OR ci_upper_bound is NULL
THEN null
ELSE ARRAY[ci_lower_bound, ci_upper_bound]
END as confidence_interval,
min,
max,
{first_quartile}
{median}
{third_quartile}
{ntiles}
mfv_value as most_frequent_values,
mfv_count::bigint[] as mfv_frequencies
FROM
(
{query}
) q3
"""
if create_table:
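            # Greenplum-only clause: the m4 macro strips DISTRIBUTED BY
            # when building for PostgreSQL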
final_query += """
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (group_by)')"""
return final_query.format(
output_table=self._output_table,
query=query,
distinct_values=distinct_values,
first_quartile=first_quartile,
median=median,
third_quartile=third_quartile,
ntiles=ntiles)
def run(self):
self._validate_params()
self._populate_columns()
        _assert(self._columns,
                "Summary error: Invalid column names: {0}".format(self._target_cols))
self._adjust_cols()
create_table = True
        # Cap the number of columns handled per query to avoid out-of-memory
        # issues when many columns are computed concurrently. The query is
        # repeated until statistics have been computed for all columns.
actual_n_cols = len(self._columns)
        # Spread the columns evenly across the repeated runs. For example,
        # with self._n_cols_per_run = 15 and 31 columns, split them as
        # [11, 11, 9] instead of [15, 15, 1] to keep per-subquery memory low.
n_splits = math.ceil(float(actual_n_cols) / self._n_cols_per_run)
subset_n_cols = int(math.ceil(actual_n_cols / n_splits))
subset_columns = [self._columns[pos: pos + subset_n_cols]
for pos in range(0, actual_n_cols, subset_n_cols)]
for cols in subset_columns:
for group_val in self._grouping_cols:
# summary treats the comma-separated list of grouping_cols as
# "group by each" val in the list
plpy.execute(self._build_query(group_val, cols, create_table))
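                # The first query creates the output table; subsequent
                # batches INSERT into it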
create_table = False