"""
@file correlation.py_in
@brief Cross-correlation function for multiple columns of a relation
@namespace correlation
"""
from time import time
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import py_list_to_sql_string
from utilities.utilities import split_quoted_delimited_str
from utilities.utilities import unique_string
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import does_exclude_reserved
from utilities.validate_args import get_cols_and_types
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
def correlation(schema_madlib, source_table, output_table,
target_cols, grouping_cols, get_cov=False,
verbose=False, n_groups_per_run=10, **kwargs):
"""
Populates an output table with the coefficients of correlation between
the columns in a source table
Args:
@param schema_madlib MADlib schema namespace
@param source_table Name of input table
@param output_table Name of output table
@param target_cols Name of specific columns targetted for correlation
@param get_cov If False return the correlation matrix else
return the covariance matrix
@param verbose Flag to determine whether to output debug info
Returns:
Tuple (output table name, number of columns, time for computation)
"""
with MinWarning("info" if verbose else "error"):
function_name = 'Covariance' if get_cov else 'Correlation'
_validate_corr_arg(source_table, output_table, function_name)
_numeric_column_names, _nonnumeric_column_names = _get_numeric_columns(source_table)
_target_cols = _analyze_target_cols(source_table, target_cols,
function_name)
if n_groups_per_run is None:
n_groups_per_run = 10
# Validate grouping_cols param
if grouping_cols:
_validate_grouping_cols_param(source_table, grouping_cols,
function_name, n_groups_per_run)
if _target_cols:
# prune all non-numeric column types from target columns
_existing_target_cols = []
            # maintain a parallel set of the same names to efficiently check
            # membership (see below)
_existing_target_cols_check = set()
_nonexisting_target_cols = []
_nonnumeric_target_cols = []
_numeric_column_names = set(_numeric_column_names)
_nonnumeric_column_names = set(_nonnumeric_column_names)
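            # For illustration (hypothetical names): with _target_cols =
            # ['a', 'b', 'a', 'txt', 'missing'], numeric source columns
            # {'a', 'b'} and non-numeric source columns {'txt'}, the loop
            # below yields
            #   _existing_target_cols    = ['a', 'b']   (duplicate 'a' dropped)
            #   _nonnumeric_target_cols  = ['txt']
            #   _nonexisting_target_cols = ['missing']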
for col in _target_cols:
if col in _numeric_column_names:
# efficiently check membership using the set
# ensure all column names are unique since they're output column names
if col not in _existing_target_cols_check:
_existing_target_cols.append(col)
_existing_target_cols_check.add(col)
elif col in _nonnumeric_column_names:
_nonnumeric_target_cols.append(col)
else:
_nonexisting_target_cols.append(col)
else:
# if target_cols not provided then all numeric columns are to be targeted
_existing_target_cols = _numeric_column_names[:]
_nonnumeric_target_cols = _nonnumeric_column_names[:]
_nonexisting_target_cols = []
if not _existing_target_cols:
plpy.error("Correlation error: No numeric column found in the target list.")
if len(_existing_target_cols) == 1:
plpy.error("Correlation error: Only one numeric column found in the target list.")
run_time = _populate_output_table(schema_madlib, source_table, output_table,
_existing_target_cols, grouping_cols,
n_groups_per_run, function_name, get_cov,
verbose)
# ---- Output message ----
output_text_list = ["Summary for '{0}' function".format(function_name)]
output_text_list.append("Output table = " + str(output_table))
if _nonnumeric_target_cols:
output_text_list.append("Non-numeric columns ignored: {0}".
format(','.join(_nonnumeric_target_cols)))
if _nonexisting_target_cols:
output_text_list.append("Columns that don't exist in '{0}' ignored: {1}".
format(source_table, ','.join(_nonexisting_target_cols)))
if grouping_cols:
output_text_list.append("Grouping columns: {0}".
format(grouping_cols))
output_text_list.append("Producing {0} for columns: {1}".
format(function_name.lower(), ','.join(_existing_target_cols)))
output_text_list.append("Total run time = " + str(run_time))
# ---- Output message ----
return '\n'.join(output_text_list)
# ------------------------------------------------------------------------------
def _validate_corr_arg(source_table, output_table, function_name):
"""
Validates all arguments and raises an error if there is an invalid argument
Args:
@param source_table Name of input table (string)
@param output_table Name of output table (string)
@param target_cols Comma separated list of output columns (string)
Returns:
True if all arguments are valid
"""
input_tbl_valid(source_table, function_name)
output_tbl_valid(output_table, function_name)
output_tbl_valid(add_postfix(output_table, "_summary"), function_name)
# ------------------------------------------------------------------------------
def _validate_grouping_cols_param(source_table, grouping_cols, function_name,
n_groups_per_run):
grouping_cols_list = split_quoted_delimited_str(grouping_cols)
# Column names that are used in summary table.
reserved_cols_in_summary_table = set(['method',
'source',
'output_table',
'column_names',
'mean_vector',
'total_rows_processed'])
cols_in_tbl_valid(source_table, grouping_cols_list, function_name)
does_exclude_reserved(grouping_cols_list, reserved_cols_in_summary_table)
    _assert(n_groups_per_run > 0,
            "{0}: n_groups_per_run has to be greater than 0.".format(function_name))
# ------------------------------------------------------------------------------
def _get_numeric_columns(source_table):
"""
Returns all column names for numeric type columns in a relation
Args:
@param source_table
Returns:
List of column names in table
"""
# retrieve the numeric columns
numeric_types = ('smallint', 'integer', 'bigint',
'real', 'numeric', 'double precision')
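    # Note: this matches on the declared type names that get_cols_and_types is
    # assumed to return; any other type (e.g. 'text' or 'date') is treated as
    # non-numeric and later reported back to the caller as ignored.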
numeric_col_names = []
nonnum_col_names = []
for col_name, col_type in get_cols_and_types(source_table):
if col_type in numeric_types:
numeric_col_names.append(col_name)
else:
nonnum_col_names.append(col_name)
return (numeric_col_names, nonnum_col_names)
# ------------------------------------------------------------------------------
def _analyze_target_cols(source_table, target_cols, function_name):
"""
Analyzes target_cols string input and converts it to a list
"""
if not target_cols or target_cols.strip() in ('', '*'):
target_cols = None
else:
target_cols = split_quoted_delimited_str(target_cols)
reserved_cols_in_output_table = set(['column_position',
'variable'])
cols_in_tbl_valid(source_table, target_cols, function_name)
does_exclude_reserved(target_cols, reserved_cols_in_output_table)
return target_cols
# ------------------------------------------------------------------------------
def _populate_output_table(schema_madlib, source_table, output_table,
col_names, grouping_cols, n_groups_per_run,
function_name, get_cov=False, verbose=False):
"""
Creates a relation with the appropriate number of columns given a list of
column names and populates with the correlation coefficients. If the table
already exists, then it is dropped before creating.
    Args:
        @param schema_madlib      Schema of MADlib
        @param source_table       Name of source table
        @param output_table       Name of output table
        @param col_names          Names of all columns to place in the output
                                  table
        @param grouping_cols      Names of all columns to be used for grouping
        @param n_groups_per_run   Number of groups to deconstruct per query
        @param function_name      'Correlation' or 'Covariance', used in
                                  messages
        @param get_cov            If False return the correlation matrix, else
                                  return the covariance matrix
        @param verbose            Flag to determine whether to output debug
                                  info
    Returns:
        Tuple (output table name, number of columns, time for computation)
    """
with MinWarning("info" if verbose else "error"):
start = time()
col_len = len(col_names)
col_names_as_text_array = py_list_to_sql_string(col_names, "varchar")
# Create unique strings to be used in queries.
coalesced_col_array = unique_string(desp='coalesced_col_array')
mean_col = unique_string(desp='mean')
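        # Build the aggregate expression. For covariance, the aggregate result
        # is scaled by 1.0/count(*) (population covariance, presumably over a
        # sum of centered cross-products), with NULL returned for an empty
        # group; correlation_agg yields the normalized (Pearson) matrix
        # directly.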
if get_cov:
agg_str = """
(CASE WHEN count(*) > 0
THEN {0}.array_scalar_mult({0}.covariance_agg({1}, {2}),
1.0 / count(*)::double precision)
ELSE NULL
END) """.format(schema_madlib, coalesced_col_array, mean_col)
else:
agg_str = "{0}.correlation_agg({1}, {2})".format(schema_madlib,
coalesced_col_array,
mean_col)
cols = ','.join(["coalesce({0}, {1})".format(col, add_postfix(col, "_avg"))
for col in col_names])
avgs = ','.join(["avg({0}) AS {1}".format(col, add_postfix(col, "_avg"))
for col in col_names])
avg_array = ','.join([str(add_postfix(col, "_avg")) for col in col_names])
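        # For illustration, with col_names = ['x', 'y'] (hypothetical names,
        # assuming add_postfix appends '_avg'), the strings built above are:
        #   cols      -> "coalesce(x, x_avg),coalesce(y, y_avg)"
        #   avgs      -> "avg(x) AS x_avg,avg(y) AS y_avg"
        #   avg_array -> "x_avg,y_avg"
        # i.e. NULLs are replaced by the column mean so they do not perturb
        # the aggregates.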
# Create unique strings to be used in queries.
tot_cnt = unique_string(desp='tot_cnt')
cor_mat = unique_string(desp='cor_mat')
temp_output_table = unique_string(desp='temp_output')
subquery1 = unique_string(desp='subq1')
subquery2 = unique_string(desp='subq2')
grouping_cols_comma = ''
subquery_grouping_cols_comma = ''
inner_group_by = ''
# Cross join if there are no groups to consider
join_condition = ' ON (1=1) '
if grouping_cols:
group_col_list = split_quoted_delimited_str(grouping_cols)
grouping_cols_comma = add_postfix(grouping_cols, ', ')
subquery_grouping_cols_comma = get_table_qualified_col_str(
subquery2, group_col_list) + " , "
inner_group_by = " GROUP BY {0}".format(grouping_cols)
join_condition = " USING ({0})".format(grouping_cols)
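        # The query below is built in two layers: the inner sub-query computes
        # the per-group column means, which are joined back onto every source
        # row so NULLs can be coalesced to the group mean; the outer aggregate
        # then emits one row per group (or a single row without grouping)
        # holding the count, mean vector and correlation/covariance matrix.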
create_temp_output_table_query = """
CREATE TEMP TABLE {temp_output_table} AS
SELECT
{subquery_grouping_cols_comma}
count(*) AS {tot_cnt},
{mean_col},
{agg_str} AS {cor_mat}
FROM
(
SELECT {grouping_cols_comma}
ARRAY[ {cols} ] AS {coalesced_col_array},
ARRAY [ {avg_array} ] AS {mean_col}
FROM {source_table}
JOIN
(
SELECT {grouping_cols_comma} {avgs}
FROM {source_table}
{inner_group_by}
) {subquery1}
{join_condition}
) {subquery2}
GROUP BY {grouping_cols_comma} {mean_col}
""".format(**locals())
plpy.execute(create_temp_output_table_query)
# Prepare the query for converting the matrix into the lower triangle
deconstruction_query_list = _create_deconstruction_query(schema_madlib,
col_names,
grouping_cols,
temp_output_table,
cor_mat,
n_groups_per_run)
variable_subquery = unique_string(desp='variable_subq')
matrix_subquery = unique_string(desp='matrix_subq')
# create output table
select_deconstruct_query = """
SELECT *
FROM
(
SELECT
generate_series(1, {num_cols}) AS column_position,
unnest({col_names_as_text_array}) AS variable
) {variable_subquery}
JOIN
(
{deconstruction_query}
) {matrix_subquery}
USING (column_position)
"""
        # Create the output table.
        # If there are no grouping cols, the query list will have a single
        # element, so we always execute the 0'th element of the list. If there
        # are grouping cols, we loop through the rest of the list and execute
        # the remaining queries one by one.
create_output_table_query = """
CREATE TABLE {output_table} AS
{select_deconstruct_query}
""".format(**locals()).format(
deconstruction_query=deconstruction_query_list[0],
num_cols=len(col_names), **locals())
plpy.execute(create_output_table_query)
if grouping_cols:
for i in range(1, len(deconstruction_query_list)):
insert_to_output_table_query = """
INSERT INTO {output_table}
{select_deconstruct_query}
""".format(**locals()).format(
deconstruction_query=deconstruction_query_list[i],
num_cols=len(col_names), **locals())
plpy.execute(insert_to_output_table_query)
# create summary table
summary_table = add_postfix(output_table, "_summary")
create_summary_table_query = """
CREATE TABLE {summary_table} AS
SELECT
'{function_name}'::varchar AS method,
'{source_table}'::varchar AS source,
'{output_table}'::varchar AS output_table,
{col_names_as_text_array} AS column_names,
{grouping_cols_comma}
{mean_col} AS mean_vector,
{tot_cnt} AS total_rows_processed
FROM {temp_output_table}
""".format(**locals())
plpy.execute(create_summary_table_query)
# clean up and return
plpy.execute("DROP TABLE IF EXISTS {temp_output_table}".format(**locals()))
end = time()
return (output_table, len(col_names), end - start)
# ------------------------------------------------------------------------------
def _create_deconstruction_query(schema_madlib, col_names, grouping_cols,
temp_output_table, cor_mat, n_groups_per_run):
"""
Creates the query to convert the matrix into the lower-traingular format.
Args:
@param schema_madlib Schema of MADlib
@param col_names Name of all columns to place in output table
@param grouping_cols Name of all columns to be used for grouping
@param temp_output_table Name of the temporary table that contains
the matrix to deconstruct
@param cor_mat Name of column that containss the matrix
to deconstruct
@param n_groups_per_run Number of groups to deconstruct in a single
sub-query
Returns:
List of Strings where each string is a SQL sub-query for deconstructing
the matrix. Each sub-query covers n_groups_per_run number of union all
queries
"""
    # The matrix that holds the PCC computation must be converted to a
    # table capturing all pairwise PCC values. That is done using
    # a UDF named __deconstruct_lower_triangle.
    # With grouping, calling that UDF becomes a lot more complex, so
    # construct the query accordingly.
COL_WIDTH = 10
    # Split col_names into equal-sized chunks, joined with newlines, to keep
    # any single line of the query from growing too long. Build a 2-d list of
    # the col_names, each inner list holding at most COL_WIDTH names.
col_names_split = [col_names[x : x + COL_WIDTH]
for x in range(0, len(col_names), COL_WIDTH)]
    variable_list_str = ', \n'.join(
        [', '.join(['{0} float8'.format(col) for col in cols_blob])
         for cols_blob in col_names_split])
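    # For illustration, with col_names = ['x', 'y', 'z'] (hypothetical) and
    # COL_WIDTH = 10, col_names_split is [['x', 'y', 'z']] and
    # variable_list_str is "x float8, y float8, z float8", i.e. the column
    # definition list for the deconstructed output.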
# Fix for MADLIB-1301.
# Creating a huge chain of union all sub-queries might create problems with
# memory/stack/execution time. We divide them into chunks of 10 (or
# whatever the user decides) and insert these chunks one by one. So create
# a list of these sub-query chunks. If no grouping cols are provided, the
# list will contain only one sub-query that will be executed.
deconstruction_query_list = []
if grouping_cols:
grp_dict_rows = plpy.execute("SELECT {0} FROM {1}".format(
grouping_cols,
temp_output_table))
deconstruction_grp_queries_list = list()
for grp_dict in grp_dict_rows:
where_condition = 'WHERE ' + ' AND '.join("{0} = '{1}'".format(k, v)
for k, v in grp_dict.items())
select_grouping_cols = ' , '.join("'{1}' AS {0}".format(k, v)
for k, v in grp_dict.items())
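            # For illustration, with a group row {'state': 'CA'} (hypothetical),
            # the strings above become
            #   where_condition      -> "WHERE state = 'CA'"
            #   select_grouping_cols -> "'CA' AS state"
            # so each per-group sub-query tags its rows with the group values.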
deconstruction_grp_queries_list.append("""
SELECT {select_grouping_cols}, *
FROM {schema_madlib}.__deconstruct_lower_triangle(
(SELECT {cor_mat} FROM {temp_output_table} {where_condition})
) AS deconstructed(column_position integer, {variable_list_str})
""".format(**locals()))
for i in range(0, len(deconstruction_grp_queries_list), n_groups_per_run):
sublist = deconstruction_grp_queries_list[i:i+n_groups_per_run]
deconstruction_query_list.append(' UNION ALL '.join(sublist))
else:
deconstruction_query = """
SELECT * FROM
{schema_madlib}.__deconstruct_lower_triangle(
(SELECT {cor_mat} FROM {temp_output_table})
) AS deconstructed(column_position integer, {variable_list_str})
""".format(**locals())
deconstruction_query_list = [deconstruction_query]
    return deconstruction_query_list
# ------------------------------------------------------------------------------
def correlation_help_message(schema_madlib, message, cov=False, **kwargs):
"""
Given a help string, provide usage information
"""
func = "covariance" if cov else "correlation"
if message is not None and \
message.lower() in ("usage", "help", "?"):
return """
Usage:
-----------------------------------------------------------------------
SELECT {schema_madlib}.{func}
(
    source_table     TEXT,    -- Source table name (Required)
    output_table     TEXT,    -- Output table name (Required)
    target_cols      TEXT,    -- Comma-separated columns for which summary is desired
                              --   (Default: '*' - produces result for all columns)
    verbose          BOOLEAN, -- Verbosity
    grouping_cols    TEXT,    -- Comma-separated columns for grouping
    n_groups_per_run INTEGER  -- Number of groups to process at a time
)
-----------------------------------------------------------------------
Output will be a table with N+2 columns and N rows, where N is the number
of numeric columns in 'target_cols'.
The columns of the table are described as follows:
- column_position : Position of the variable in the 'source_table'.
- variable : Provides the row-header for each variable
- Rest of the table is the NxN {func} matrix for all numeric columns
in 'source_table'.
The output table is arranged as a lower-triangular matrix with the upper
triangle set to NULL. To obtain the result from the output_table in this
matrix format, make sure to order the rows using the 'column_position' column.
""".format(schema_madlib=schema_madlib, func=func)
else:
if cov:
return """
Covariance is a measure of how much two random variables change together. If the
greater values of one variable mainly correspond with the greater values of the
other variable, and the same holds for the smaller values, i.e., the variables
tend to show similar behavior, the covariance is positive. In the opposite
case, when the greater values of one variable mainly correspond to the smaller
values of the other, i.e., the variables tend to show opposite behavior, the
covariance is negative. The sign of the covariance therefore shows the
tendency in the linear relationship between the variables.
-------
For an overview on usage, run:
SELECT {schema_madlib}.covariance('usage');
""".format(schema_madlib=schema_madlib)
else:
return """
Correlation measures the degree and direction of association between two
variables, i.e. how well one random variable can be predicted from the
other. The coefficient of correlation varies from -1 to 1: 1 implies
perfect positive correlation, 0 means no correlation, and -1 means the
variables are perfectly anti-correlated.
-------
For an overview on usage, run:
SELECT {schema_madlib}.correlation('usage');
""".format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------