blob: 6d0ebaea9fdd71b8d3fcdebfdabafb59e711c2f0 [file] [log] [blame]
# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Pivoting
# The goal of the MADlib pivot function is to provide a data summarization tool
# that can do basic OLAP type operations on data stored in one table and output
# the summarized data to a second table. Typical operations are count, average,
# min, max and standard deviation, however user defined aggregates (UDAs) are
# also allowed.
# Please refer to the pivot.sql_in file for the documentation
"""
@file pivot.py_in
"""
import plpy
import itertools
from control import MinWarning
from utilities import _assert
from utilities import split_quoted_delimited_str
from utilities import strip_end_quotes
from utilities import extract_keyvalue_params
from validate_args import table_exists
from validate_args import columns_exist_in_table
from validate_args import table_is_empty
from validate_args import _get_table_schema_names
from validate_args import get_first_schema
from validate_args import get_expr_type
m4_changequote(`<!', `!>')
def pivot(schema_madlib, source_table, out_table, index, pivot_cols,
pivot_values, aggregate_func=None, fill_value=None, keep_null=False,
output_col_dictionary=False, output_type=None, **kwargs):
# NOTE(review): this copy of the function has lost its indentation, its
# docstring delimiters and a number of lines (try:/else: lines, .format(
# arguments, parts of the SQL strings). The comments added below describe only
# what the remaining lines show and flag the gaps that are visible.
#
# Visible flow: split the index / pivot / value column lists, resolve the
# output type, validate the arguments, collect the distinct values of every
# pivot column, build one aggregate expression per (value column, aggregate,
# pivot-value combination), then create the output table and, when requested
# or required, a "<out_table>_dictionary" table describing generated names.
Helper function that can be used to pivot tables
@param source_table The original data table
@param out_table The output table that contains the dummy
variable columns
@param index The index columns to group the records by
@param pivot_cols The columns to pivot the table
@param pivot_values The value columns to be summarized in the
pivoted table
@param aggregate_func The aggregate function to be applied to the
@param fill_value If specified, determines how to fill NULL
values resulting from pivot operation
@param keep_null The flag for determining how to handle NULL
values in pivot columns
Assume we have the following table
pivset( id INTEGER, piv FLOAT8, val FLOAT8 )
where the piv column has 3 distinct values (10, 20 and 30).
If the pivot function call is :
SELECT madlib.pivot('pivset', 'pivout', 'id', 'piv', 'val');
We want to construct the following sql code to pivot the table.
avg(CASE WHEN "piv" = '10' THEN val ELSE NULL END ) as "val_avg_piv_10",
avg(CASE WHEN "piv" = '20' THEN val ELSE NULL END ) as "val_avg_piv_20",
avg(CASE WHEN "piv" = '30' THEN val ELSE NULL END ) as "val_avg_piv_30"
FROM pivset GROUP BY id ORDER BY id)
def _fill_value_wrapper(sel_str):
""" Wrap a given SQL SELECT statement with COALESCE using a given fill value.
No-op if the fill value is not provided
# NOTE(review): the closing delimiter of the docstring above is not visible.
# fill_value is read from the enclosing pivot() call; it is pasted into the
# SQL text as-is (no quoting is applied here).
if fill_value is not None:
return " COALESCE({0}, {1}) ".format(sel_str, fill_value)
return sel_str
with MinWarning('warning'):
# If there are more than 1000 columns for the output table, we give a
# warning as it might give an error.
# If a column name has more than 63 characters it gets trimmed automatically,
# which may cause an exception. Enable the output dictionary in this case.
# NOTE(review): MAX_OUTPUT_COLUMN_COUNT and MAX_COLUMN_LENGTH are read further
# down, but no assignment to either name is visible in this copy; presumably
# they were defined next to the two comments above (1000 and 63) -- confirm.
indices = split_quoted_delimited_str(index)
pcols = split_quoted_delimited_str(pivot_cols)
pvals = split_quoted_delimited_str(pivot_values)
# output type for specific supported types
output_type = 'column' if not output_type else output_type.lower()
all_output_types = sorted(['array', 'column', 'svec'])
# allow user to specify a prefix substring of
# supported output types. This works because the supported
# output types have unique prefixes.
# NOTE(review): the "try:" that pairs with the "except StopIteration" below is
# not visible in this copy.
output_type = next(s for s in all_output_types
if s.startswith(output_type))
except StopIteration:
# next() returns a StopIteration if no element found
plpy.error("Encoding categorical: Output type should be one of {0}".
# NOTE(review): the arguments of the .format( call above are cut off.
is_array_output = output_type in ('array', 'svec')
# always build dictionary table if output is array
output_col_dictionary = True if is_array_output else output_col_dictionary
validate_pivot_coding(source_table, out_table, indices, pcols, pvals)
# Strip the end quotes for building output columns (this can only be
# performed after the validation)
pcols = [strip_end_quotes(pcol.strip()) for pcol in pcols]
pvals = [strip_end_quotes(pval.strip()) for pval in pvals]
# Create a dictionary that assigns one or more aggregate functions for every
# value column.
agg_dict = parse_aggregates(pvals, aggregate_func)
validate_output_types(source_table, agg_dict, is_array_output)
# Find the distinct values of pivot_cols
array_agg_str = ', '.join("array_agg(DISTINCT {pcol}) AS {pcol}_values".
format(pcol=pcol) for pcol in pcols)
if keep_null:
# Some platforms don't include NULL values as part of the array_agg(DISTINCT ...)
# Below clause checks explicitly for NULL values
null_str = ", " + ', '.join(
"bool_or(CASE WHEN {pcol} IS NULL THEN True ELSE False END)"
"AS {pcol}_isnull".format(pcol=pcol) for pcol in pcols)
# NOTE(review): presumably the reset below is the else-branch of
# "if keep_null:" (the "else:" line is not visible) -- confirm.
null_str = ""
distinct_values = plpy.execute("SELECT {0} {1} FROM {2}".
format(array_agg_str, null_str, source_table))[0]
# Collect the distinct values for every pivot column into a dictionary
pcol_distinct_values = {}
pcol_max_length = 0
for pcol in pcols:
pcol_tmp = set(item for item in distinct_values[pcol + "_values"])
# NOTE(review): the bodies of the two branches below are not visible;
# presumably they remove / add the NULL (None) entry of pcol_tmp -- confirm.
if not keep_null:
elif distinct_values[pcol + "_isnull"]:
pcol_distinct_values[pcol] = sorted(pcol_tmp)
# Max pcol length calculation: the name of column (pcol) +
# name of longest value in column (item) +
# underscore (1)
pcol_max_length += (len(pcol) +
max([len(str(item)) for item in pcol_tmp]) +
# NOTE(review): the remainder of the expression above is cut off.
# Create the combination of every possible pivot column
# Assume piv and piv2 are pivot columns. piv=(1,2) and piv2=(3,4,5)
# pivot_comb = ((1,3),(1,4),(1,5),(2,3),(2,4),(2,5))
pivot_comb = list(itertools.product(*([pcol_distinct_values[pcol]
for pcol in pcols])))
# Check the max possible length of a output column name
# If it is over 63 (postgresql upper limit) create dictionary lookup
for pval in pvals:
agg_func = agg_dict[pval]
# Length calculation: value column length + aggregate length +
# 2 underscores + pivots and their values (pcol_max_length)
# Example: val _ sum _ piv1_10_piv2_100
col_name_len = (2 + len(pval) + pcol_max_length +
max([len(item) for item in agg_func]))
if col_name_len > MAX_COLUMN_LENGTH:
with MinWarning("warning"):
plpy.warning("Pivot: Output columns are renamed to keep them "
"under 63 characters. Please refer to "
"{source_table}_dictionary for the original names.".
# NOTE(review): the .format( arguments of the warning above are cut off.
output_col_dictionary = True
# Types of pivot columns are needed for building the right columns
# in the dictionary table and to decide if a pivot column value needs to
# be quoted during comparison (will be quoted if it's a text column)
types_str = ', '.join("pg_typeof(\"{pcol}\") as {pcol}".
format(pcol=p) for p in pcols)
pcol_types = plpy.execute("SELECT {0} FROM {1} LIMIT 1".
format(types_str, source_table))[0]
if output_col_dictionary:
out_dict = out_table + "_dictionary"
_assert(not table_exists(out_dict),
"Pivot: Output dictionary table already exists!")
# Create the empty dictionary table
pcol_names_types = ', '.join(" {pcol} {pcol_type} ".
for pcol in pcols)
# NOTE(review): the .format( call of the join above and the opening of the
# statement that runs the CREATE TABLE text below are only partly visible.
CREATE TABLE {out_dict} (
__pivot_cid__ VARCHAR,
col_name VARCHAR)
""".format(out_dict=out_dict, pcol_names_types=pcol_names_types))
# List of rows to insert into output dictionary
dict_insert_str = []
# Counter for the new output column names
dict_counter = 1
pivot_sel_list = []
pivot_from_list = []
for pval in pvals:
agg_func = agg_dict[pval]
for agg in agg_func:
# is using array_output, create a new array for each pval-agg combo
if is_array_output:
# we store information in the dictionary table for each
# index in the array. 'index_counter' is the current index
# being updated (resets for each new array)
index_counter = 1
sub_pivot_sel_list = []
for comb in pivot_comb:
pivot_col_condition = []
# note column name starts with double quotes
pivot_col_name = ['{pval}_{agg}'.format(pval=pval, agg=agg)]
if output_col_dictionary:
# Prepare the entry for the dictionary
if not is_array_output:
index_name = ("__p_{dict_counter}__".
# for arrays, index_name is just the index into each array
index_name = str(index_counter)
index_counter += 1
# NOTE(review): the call that receives the formatted tuple text below
# (presumably an append to dict_insert_str) is not visible -- confirm.
"(\'{index_name}\', \'{pval}\', \'{agg}\' ".
format(index_name=index_name, pval=pval, agg=agg))
# For every pivot column in a given combination
for counter, pcol in enumerate(pcols):
if comb[counter] is None:
quoted_pcol_value = "NULL"
elif pcol_types[pcol] in ("text", "varchar", "character varying"):
quoted_pcol_value = "'" + comb[counter] + "'"
quoted_pcol_value = comb[counter]
# If we encounter a NULL value that means it is not filtered
# because of keep_null. Use "IS NULL" for comparison
if comb[counter] is None:
pivot_col_condition.append(" \"{0}\" IS NULL".format(pcol))
pivot_col_condition.append(" \"{0}\" = {1}".
format(pcol, quoted_pcol_value))
pivot_col_name.append("_{0}_{1}".format(pcol, comb[counter]))
# NOTE(review): the statements guarded by the two identical conditions below
# are only partly visible in this copy.
if output_col_dictionary:
if output_col_dictionary:
# Store the whole string as additional info
pivot_col_name = ["__p_" + str(dict_counter) + "__"]
dict_counter += 1
# Collecting the whole sql query
# Please refer to the earlier comment for a sample output
# Build the pivot column with NULL values in tuples that don't
# satisfy that column's condition
p_name = '"{0}"'.format(''.join(pivot_col_name))
pivot_str_from = (
"(CASE WHEN {condition} THEN {pval} END) AS {p_name}".
condition=' AND '.join(pivot_col_condition),
# Aggregate over each pivot column, while filtering all NULL
# values created by previous query.
sub_pivot_str_sel = _fill_value_wrapper(
"{agg}({p_name}) "
format(agg=agg, p_name=p_name))
if not is_array_output:
# keep spaces around the 'AS'
sub_pivot_str_sel += " AS " + p_name
if sub_pivot_sel_list:
if is_array_output:
# NOTE(review): "is" tests object identity, not string equality; "==" is the
# reliable comparison for this check. Left unchanged here.
if output_type is 'svec':
cast_str = '::FLOAT8[]::{0}.svec'.format(schema_madlib)
cast_str = '::FLOAT8[]'
'ARRAY[{all_pivot_sel}]{cast_str} AS "{pval}_{agg}"'.
format(all_pivot_sel=', '.join(sub_pivot_sel_list),
pivot_sel_list += sub_pivot_sel_list
# NOTE(review): the opening of the statement that runs the SQL below (and the
# "try:" paired with "except plpy.SPIError") is not visible, and the lines
# carrying the select-list placeholders appear to be missing from the text.
CREATE TABLE {out_table} AS
SELECT {index},
SELECT {index},
FROM {source_table}
) x
GROUP BY {index}
all_pivot_from_str=', '.join(pivot_from_list),
all_pivot_sel_str=', '.join(pivot_sel_list)
if output_col_dictionary:
plpy.execute("INSERT INTO {out_dict} VALUES {insert_sql}".
insert_sql=', '.join(dict_insert_str)))
except plpy.SPIError:
# Warn user if the number of columns is over the limit
with MinWarning("warning"):
# The column options from value columns and aggregates
# times the number of pivot combinations
if ((sum([len(item) for item in agg_dict.values()])*
len(pivot_comb)) > MAX_OUTPUT_COLUMN_COUNT):
# NOTE(review): the calls that emit the two messages below are not visible.
"Pivot: Too many distinct values for pivoting! "
"The execution may fail due to too many columns in the "
"output table.")
"Pivot: Pivoting is only supported over aggregates with "
"transition functions defined as STRICT.")
return None
# ------------------------------------------------------------------------------
def parse_aggregates(pvals, aggregate_func):
    """
    Helper function that parses the aggregate function parameter

    Args:
        @param pvals            The value columns to be summarized in the
                                pivoted table
        @param aggregate_func   The aggregate function(s) to be applied to the
                                value columns

    Returns:
        Dictionary with one entry per value column in pvals; each entry holds
        the aggregate names to apply to that column.

    The aggregate_func can get one of the following forms
        1) NULL: Use the default aggregate ('avg')
        2) A single aggregate (eg. 'sum')
        3) A comma-separated list of aggregates (eg. 'sum,avg')
        4) A complete mapping (eg. 'val=sum, val2=[avg,sum]')
        5) A partial mapping (eg. 'val2=sum'): Use the default ('avg') for the
            missing value columns
    """
    default_aggs = ('avg', )

    # Try the key=value forms (4 and 5) first.
    # Assumes extract_keyvalue_params yields an empty result when the string
    # holds no key=value pairs -- TODO confirm against utilities.
    param_types = dict.fromkeys(pvals, tuple)
    agg_dict = extract_keyvalue_params(aggregate_func, param_types)

    if not agg_dict:
        # Forms 1-3: the same aggregate list applies to every value column
        agg_list = tuple(split_quoted_delimited_str(aggregate_func))
        return dict.fromkeys(pvals, agg_list if agg_list else default_aggs)

    # Form 5: value columns left out of the mapping get the default aggregate
    for pval in pvals:
        agg_dict.setdefault(pval, default_aggs)
    return agg_dict
# ------------------------------------------------------------------------------
def validate_pivot_coding(source_table, out_table, indices, pivs, vals):
    """
    Validate the table and column arguments of pivot

    Args:
        @param source_table     The original data table
        @param out_table        The output table that will contain the
                                pivoted columns
        @param indices          An array of index column names
        @param pivs             An array of pivot column names
        @param vals             An array of value column names

    Every check goes through _assert with a "Pivot: ..." message; the checks
    run in the order: output table, source table, column lists, column
    existence in the source table.
    """
    _assert(out_table and out_table.strip().lower() not in ('null', ''),
            "Pivot: Invalid output table name!")
    _assert(not table_exists(out_table),
            "Pivot: Output table already exists!")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Pivot: Invalid data table name!")
    # The source table has to exist before it can be tested for emptiness
    _assert(table_exists(source_table),
            "Pivot: Data table ({0}) is missing!". format(source_table))
    _assert(not table_is_empty(source_table),
            "Pivot: Data table ({0}) is empty!". format(source_table))

    column_groups = ((indices, 'index'), (pivs, 'pivot'), (vals, 'value'))

    # All column lists must be non-empty before any of them is looked up
    for columns, label in column_groups:
        _assert(columns and columns not in ('null', ''),
                "Pivot: Invalid {0} column!".format(label))

    for columns, _ in column_groups:
        _assert(columns_exist_in_table(source_table, columns),
                "Pivot: Not all columns from {0} present in source table ({1})"
                .format(columns, source_table))
# ------------------------------------------------------------------------------
def validate_output_types(source_table, agg_dict, is_array_output):
    """
    Reject aggregates that return arrays when the pivot output is an array

    Args:
        @param source_table: str, Name of table containing data
        @param agg_dict: dict, Key-value pair containing aggregates applied for each val column
        @param is_array_output: bool, Is the pivot output columnar (False) or array (True)

    Returns:
        None. A violation is reported through _assert.
    """
    for val, func_iterable in agg_dict.items():
        for func in func_iterable:
            func_call_str = '{0}({1})'.format(func, val)
            # The expression type is looked up for every aggregate, also for
            # columnar output, exactly as the single combined condition did.
            returns_array = '[]' in get_expr_type(func_call_str, source_table)
            _assert(not (returns_array and is_array_output),
                    "Pivot: Aggregate {0} with an array return type cannot be "
                    "combined with output_type='array' or 'svec'".format(func))
# ----------------------------------------------------------------------
def pivot_help(schema_madlib, message, **kwargs):
    """
    Help function for pivot

    Args:
        @param schema_madlib    Schema name substituted into the help text
        @param message: string, Help message string. Empty/None selects the
                                summary; 'usage', 'help' or '?' selects the
                                usage text; anything else selects a short
                                "No such option" notice
        @param kwargs           Accepted but not used

    Returns:
        String. Help/usage information
    """
    if not message:
        help_string = """
Provide a data summarization tool that can do basic OLAP type operations on
data stored in one table and output the summarized data to a second table.
Typical operations are count, average, min, max and standard deviation, however
user defined aggregates (UDAs) are also allowed.

For more details on function usage:
    SELECT {schema_madlib}.pivot('usage')
"""
    elif message in ['usage', 'help', '?']:
        # Arguments are listed in the positional order of pivot()
        help_string = """
SELECT {schema_madlib}.pivot(
    source_table,           -- Name of source table containing data for pivoting
    out_table,              -- Name of output table that contains pivoted data
    index,                  -- Comma-separated columns that will form the index
                            --   of the output pivot table
    pivot_cols,             -- Comma-separated columns that will form the
                            --   columns of the output pivot table
    pivot_values,           -- Comma-separated columns that contain the values
                            --   to be summarized in the output pivot table
    aggregate_func,         -- Aggregate function(s) to be applied to the
                            --   values (default: 'avg')
    fill_value,             -- If specified, determines how to fill NULL values
                            --   resulting from pivot operation
    keep_null,              -- The flag for determining how to handle NULL
                            --   values in pivot columns
    output_col_dictionary,  -- The flag for enabling the creation of the
                            --   output dictionary for shorter column names
    output_type             -- This parameter controls the output format
                            --   of the pivoted variables.
                            --   If 'column', a column is created for each pivot
                            --   If 'array', an array is created combining all pivots
                            --   If 'svec', the array is cast to madlib.svec
);

The output table ('out_table' above) has all the columns present in index
column list, plus additional columns for each distinct value in pivot_cols.
The column name for the pivot is
set as '<value column>_<aggregate>_<pivot name>_<pivot value>'.

A dictionary table ('<out_table>_dictionary') is created if either
'output_col_dictionary' is True, 'output_type' is 'array' or 'svec', or if the
auto-generated column names exceed the PostgreSQL limit of 63 bytes. This table
gives a mapping between short column names in 'out_table' and the meaning of
those columns i.e. which index, value, agg and pivot column they belong to.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.pivot()"
    return help_string.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------