| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # Pivoting |
| # The goal of the MADlib pivot function is to provide a data summarization tool |
| # that can do basic OLAP type operations on data stored in one table and output |
| # the summarized data to a second table. Typical operations are count, average, |
| # min, max and standard deviation, however user defined aggregates (UDAs) are |
| # also allowed. |
| |
| # Please refer to the pivot.sql_in file for the documentation |
| |
| """ |
| @file pivot.py_in |
| |
| """ |
| import plpy |
| import itertools |
| from control import MinWarning |
| from utilities import _assert |
| from utilities import split_quoted_delimited_str |
| from utilities import strip_end_quotes |
| from utilities import extract_keyvalue_params |
| from validate_args import table_exists |
| from validate_args import columns_exist_in_table |
| from validate_args import table_is_empty |
| from validate_args import _get_table_schema_names |
| from validate_args import get_first_schema |
| from validate_args import get_expr_type |
| |
| |
| m4_changequote(`<!', `!>') |
| |
| |
def pivot(schema_madlib, source_table, output_table, index, pivot_cols,
          pivot_values, aggregate_func=None, fill_value=None, keep_null=False,
          output_col_dictionary=False, output_type=None, **kwargs):
    """
    Pivot a table: aggregate value columns for every combination of distinct
    pivot-column values, grouped by the index columns.

    Args:
        @param schema_madlib    Name of the MADlib schema (used to qualify the
                                svec type for output_type='svec')
        @param source_table     The original data table
        @param output_table     The output table that contains the pivoted
                                columns
        @param index            Comma-separated index columns to group the
                                records by
        @param pivot_cols       Comma-separated columns to pivot the table
        @param pivot_values     Comma-separated value columns to be summarized
                                in the pivoted table
        @param aggregate_func   The aggregate function(s) to be applied to the
                                values (see parse_aggregates for accepted forms)
        @param fill_value       If specified, an SQL expression used to replace
                                NULL results of the pivot operation. It is
                                spliced verbatim into the generated query.
        @param keep_null        If True, NULL is treated as one more distinct
                                value of a pivot column
        @param output_col_dictionary
                                If True, output columns get short generated
                                names (__p_N__) and a dictionary table
                                (<output_table>_dictionary) maps them back.
                                Forced on for array/svec output and when a
                                generated name could exceed 63 characters.
        @param output_type      'column' (default), 'array' or 'svec'; any
                                prefix of these is accepted

    Returns:
        None. Creates output_table (and optionally the dictionary table).

    Assume we have the following table
        pivset( id INTEGER, piv INTEGER, val FLOAT8 )
    where the piv column has 3 distinct values (10, 20 and 30).
    If the pivot function call is :
        SELECT madlib.pivot('pivset', 'pivout', 'id', 'piv', 'val');
    the generated sql has (schematically) the following shape:
        CREATE TABLE pivout AS
            SELECT id,
                avg("val_avg_piv_10")
                    FILTER (WHERE "val_avg_piv_10" IS NOT NULL) AS "val_avg_piv_10",
                ...
            FROM (
                SELECT id,
                    (CASE WHEN "piv" = 10 THEN val END) AS "val_avg_piv_10",
                    ...
                FROM pivset
            ) x
            GROUP BY id
    """

    def _fill_value_wrapper(sel_str):
        """ Wrap a given SQL expression with COALESCE using the fill value.

        No-op if the fill value is not provided
        """
        if fill_value is not None:
            return " COALESCE({0}, {1}) ".format(sel_str, fill_value)
        return sel_str

    def _sql_literal(text):
        """ Quote a string as an SQL literal, doubling embedded single quotes. """
        return "'" + text.replace("'", "''") + "'"

    def _sql_ident(name):
        """ Double-quote an identifier, doubling embedded double quotes. """
        return '"' + name.replace('"', '""') + '"'

    with MinWarning('warning'):

        # If there are more than 1000 columns for the output table, we give a
        # warning as it might give an error.
        MAX_OUTPUT_COLUMN_COUNT = 1000

        # If a column name has more than 63 characters it gets trimmed
        # automatically, which may cause an exception. Enable the output
        # dictionary in this case.
        MAX_COLUMN_LENGTH = 63

        indices = split_quoted_delimited_str(index)
        pcols = split_quoted_delimited_str(pivot_cols)
        pvals = split_quoted_delimited_str(pivot_values)

        # output type for specific supported types
        output_type = 'column' if not output_type else output_type.lower()
        all_output_types = sorted(['array', 'column', 'svec'])
        try:
            # allow user to specify a prefix substring of
            # supported output types. This works because the supported
            # output types have unique prefixes.
            output_type = next(s for s in all_output_types
                               if s.startswith(output_type))
        except StopIteration:
            # next() raises StopIteration if no element is found
            plpy.error("Pivot: Output type should be one of {0}".
                       format(','.join(all_output_types)))

        is_array_output = output_type in ('array', 'svec')
        # always build dictionary table if output is array
        output_col_dictionary = True if is_array_output else output_col_dictionary

        validate_pivot_coding(source_table, output_table, indices, pcols, pvals)

        # Strip the end quotes for building output columns (this can only be
        # performed after the validation)
        pcols = [strip_end_quotes(pcol.strip()) for pcol in pcols]
        pvals = [strip_end_quotes(pval.strip()) for pval in pvals]

        # Create a dictionary that assigns one or more aggregate functions for
        # every value column.
        agg_dict = parse_aggregates(pvals, aggregate_func)

        validate_output_types(source_table, agg_dict, is_array_output)

        distinct_values = {}
        for pcol in pcols:
            # Find the distinct values of pivot_cols
            distinct_pcol_values_sql = (
                "(SELECT DISTINCT {pcol} AS {pcol}_values "
                "FROM {source_table})tmp_x".
                format(pcol=pcol, source_table=source_table))
            array_agg_str = ("array_agg({pcol}_values) AS {pcol}_values".
                             format(pcol=pcol))
            pcol_name = "{pcol}_values".format(pcol=pcol)
            distinct_values[pcol_name] = (
                plpy.execute(""" SELECT {array_agg_str}
                                 FROM {distinct_pcol_values_sql}
                             """.format(array_agg_str=array_agg_str,
                                        distinct_pcol_values_sql=distinct_pcol_values_sql)
                             ))[0][pcol_name]
            if keep_null:
                # Some platforms don't include NULL values as part of the
                # array_agg(DISTINCT ...). Below clause checks explicitly for
                # NULL values
                pcol_null_name = "{pcol}_isnull".format(pcol=pcol)
                distinct_values[pcol_null_name] = (
                    plpy.execute("""
                        SELECT bool_or(CASE WHEN {pcol} IS NULL THEN True
                                            ELSE False END)
                               AS {pcol}_isnull
                        FROM {source_table}
                        """.format(pcol=pcol, source_table=source_table)
                                 ))[0][pcol_null_name]

        # Collect the distinct values for every pivot column into a dictionary
        pcol_distinct_values = {}
        pcol_max_length = 0
        for pcol in pcols:
            pcol_tmp = set(distinct_values[pcol + "_values"])
            if not keep_null:
                pcol_tmp.discard(None)
            elif distinct_values[pcol + "_isnull"]:
                pcol_tmp.add(None)

            # A pivot column holding only NULLs (with keep_null disabled)
            # yields no output columns at all; fail with a clear message
            # instead of an opaque error from max() on an empty sequence.
            _assert(pcol_tmp,
                    "Pivot: Pivot column ({0}) has no non-NULL values!".
                    format(pcol))

            pcol_distinct_values[pcol] = sorted(pcol_tmp)
            # Max pcol length calculation: the name of column (pcol) +
            #   name of longest value in column (item) +
            #   underscore (1)
            pcol_max_length += (len(pcol) +
                                max([len(str(item)) for item in pcol_tmp]) +
                                1)

        # Create the combination of every possible pivot column
        # Assume piv and piv2 are pivot columns. piv=(1,2) and piv2=(3,4,5)
        # pivot_comb = ((1,3),(1,4),(1,5),(2,3),(2,4),(2,5))
        pivot_comb = list(itertools.product(*([pcol_distinct_values[pcol]
                                               for pcol in pcols])))

        # Check the max possible length of a output column name
        # If it is over 63 (postgresql upper limit) create dictionary lookup
        for pval in pvals:
            agg_func = agg_dict[pval]
            # Length calculation: value column length + aggregate length +
            #   2 underscores + pivots and their values (pcol_max_length)
            # Example: val _ sum _ piv1_10_piv2_100
            col_name_len = (2 + len(pval) + pcol_max_length +
                            max([len(item) for item in agg_func]))
            if col_name_len > MAX_COLUMN_LENGTH:
                with MinWarning("warning"):
                    plpy.warning("Pivot: Output columns are renamed to keep them "
                                 "under 63 characters. Please refer to "
                                 "{output_table}_dictionary for the original names.".
                                 format(output_table=output_table))
                output_col_dictionary = True
                # one warning is enough; the flag is already set
                break

        # Types of pivot columns are needed for building the right columns
        # in the dictionary table and to decide if a pivot column value needs to
        # be quoted during comparison (will be quoted if it's a text column)
        types_str = ', '.join("pg_typeof(\"{pcol}\") as {pcol}".
                              format(pcol=p) for p in pcols)
        pcol_types = plpy.execute("SELECT {0} FROM {1} LIMIT 1".
                                  format(types_str, source_table))[0]
        if output_col_dictionary:
            out_dict = output_table + "_dictionary"
            _assert(not table_exists(out_dict),
                    "Pivot: Output dictionary table already exists!")
            # Create the empty dictionary table
            pcol_names_types = ', '.join(" {pcol} {pcol_type} ".
                                         format(pcol=pcol,
                                                pcol_type=pcol_types[pcol])
                                         for pcol in pcols)
            plpy.execute("""
                CREATE TABLE {out_dict} (
                    __pivot_cid__ VARCHAR,
                    pval VARCHAR,
                    agg VARCHAR,
                    {pcol_names_types},
                    col_name VARCHAR)
                """.format(out_dict=out_dict, pcol_names_types=pcol_names_types))

        # Fragments of the rows to insert into output dictionary. Each row is
        # spread over several consecutive fragments that are joined with ', '.
        dict_insert_str = []
        # Counter for the new output column names
        dict_counter = 1

        pivot_sel_list = []
        pivot_from_list = []

        for pval in pvals:
            agg_func = agg_dict[pval]
            for agg in agg_func:

                # if using array_output, create a new array for each pval-agg
                # combo. We store information in the dictionary table for each
                # index in the array. 'index_counter' is the current index
                # being updated (resets for each new array)
                index_counter = 1

                sub_pivot_sel_list = []
                for comb in pivot_comb:
                    pivot_col_condition = []
                    pivot_col_name = ['{pval}_{agg}'.format(pval=pval, agg=agg)]

                    if output_col_dictionary:
                        # Prepare the entry for the dictionary
                        if not is_array_output:
                            index_name = ("__p_{dict_counter}__".
                                          format(dict_counter=dict_counter))
                        else:
                            # for arrays, index_name is just the index into
                            # each array
                            index_name = str(index_counter)
                            index_counter += 1
                        dict_insert_str.append(
                            "({0}, {1}, {2} ".format(_sql_literal(index_name),
                                                     _sql_literal(pval),
                                                     _sql_literal(agg)))

                    # For every pivot column in a given combination
                    for counter, pcol in enumerate(pcols):
                        pcol_value = comb[counter]
                        if pcol_value is None:
                            # If we encounter a NULL value that means it is not
                            # filtered because of keep_null. Use "IS NULL" for
                            # comparison
                            quoted_pcol_value = "NULL"
                            pivot_col_condition.append(
                                " \"{0}\" IS NULL".format(pcol))
                            pivot_col_name.append("_{0}_null".format(pcol))
                        else:
                            if pcol_types[pcol] in ("text", "varchar",
                                                    "character varying"):
                                # text values are compared as quoted literals;
                                # embedded single quotes must be doubled
                                quoted_pcol_value = _sql_literal(pcol_value)
                            else:
                                quoted_pcol_value = pcol_value
                            pivot_col_condition.append(
                                " \"{0}\" = {1}".format(pcol, quoted_pcol_value))
                            pivot_col_name.append(
                                "_{0}_{1}".format(pcol, pcol_value))

                        if output_col_dictionary:
                            dict_insert_str.append(
                                "{0}".format(quoted_pcol_value))

                    if output_col_dictionary:
                        # Store the whole string as additional info
                        dict_insert_str.append(
                            "{0})".format(_sql_literal(''.join(pivot_col_name))))
                        pivot_col_name = ["__p_" + str(dict_counter) + "__"]
                        dict_counter += 1

                    # Collecting the whole sql query
                    # Please refer to the docstring for a sample output
                    # Build the pivot column with NULL values in tuples that
                    # don't satisfy that column's condition
                    p_name = _sql_ident(''.join(pivot_col_name))
                    pivot_str_from = (
                        "(CASE WHEN {condition} THEN {pval} END) AS {p_name}".
                        format(pval=pval,
                               condition=' AND '.join(pivot_col_condition),
                               p_name=p_name))
                    pivot_from_list.append(pivot_str_from)

                    # Aggregate over each pivot column, while filtering all
                    # NULL values created by previous query.
                    sub_pivot_str_sel = _fill_value_wrapper(
                        "{agg}({p_name}) "
                        " FILTER (WHERE {p_name} IS NOT NULL)".
                        format(agg=agg, p_name=p_name))
                    if not is_array_output:
                        # keep spaces around the 'AS'
                        sub_pivot_str_sel += " AS " + p_name
                    sub_pivot_sel_list.append(sub_pivot_str_sel)

                if sub_pivot_sel_list:
                    if is_array_output:
                        # string equality, not identity: 'is' on str objects
                        # only works by accident of interning
                        if output_type == 'svec':
                            cast_str = '::FLOAT8[]::{0}.svec'.format(schema_madlib)
                        else:
                            cast_str = '::FLOAT8[]'
                        pivot_sel_list.append(
                            'ARRAY[{all_pivot_sel}]{cast_str} AS "{pval}_{agg}"'.
                            format(all_pivot_sel=', '.join(sub_pivot_sel_list),
                                   cast_str=cast_str,
                                   pval=pval,
                                   agg=agg))
                    else:
                        pivot_sel_list += sub_pivot_sel_list

        try:
            plpy.execute("""
                CREATE TABLE {output_table} AS
                    SELECT {index},
                           {all_pivot_sel_str}
                    FROM (
                        SELECT {index},
                               {all_pivot_from_str}
                        FROM {source_table}
                    ) x
                    GROUP BY {index}
                """.format(output_table=output_table,
                           index=index,
                           source_table=source_table,
                           all_pivot_from_str=', '.join(pivot_from_list),
                           all_pivot_sel_str=', '.join(pivot_sel_list)
                           ))

            if output_col_dictionary:
                plpy.execute("INSERT INTO {out_dict} VALUES {insert_sql}".
                             format(out_dict=out_dict,
                                    insert_sql=', '.join(dict_insert_str)))
        except plpy.SPIError:
            # Give the user a hint about the likely cause, then re-raise the
            # original database error.
            with MinWarning("warning"):
                # The column options from value columns and aggregates
                # times the number of pivot combinations
                if ((sum([len(item) for item in agg_dict.values()]) *
                     len(pivot_comb)) > MAX_OUTPUT_COLUMN_COUNT):
                    plpy.warning(
                        "Pivot: Too many distinct values for pivoting! "
                        "The execution may fail due to too many columns in the "
                        "output table.")
                else:
                    plpy.warning(
                        "Pivot: Pivoting is only supported over aggregates with "
                        "transition functions defined as STRICT.")
            raise

    return None
| # ------------------------------------------------------------------------------ |
| |
| |
def parse_aggregates(pvals, aggregate_func):
    """
    Helper function that parses the aggregate function parameter

    Args:
        @param pvals            The value columns to be summarized in the
                                pivoted table
        @param aggregate_func   The aggregate function to be applied to the
                                values

    Returns:
        Dictionary mapping each value column to a tuple of aggregate names.

    The aggregate_func can get one of the following forms
        1) NULL: Use the default aggregate ('avg')
        2) A single aggregate (eg. 'sum')
        3) A comma-separated list of aggregates (eg. 'sum,avg')
        4) A complete mapping (eg. 'val=sum, val2=[avg,sum]')
        5) A partial mapping (eg. 'val2=sum'): Use the default ('avg') for the
            missing value columns
    """
    # First try to read the parameter as a 'column=aggregate(s)' mapping
    expected_types = dict.fromkeys(pvals, tuple)
    col_to_aggs = extract_keyvalue_params(aggregate_func, expected_types)

    if col_to_aggs:
        # Forms 4 and 5: fill in the default for any column left unmapped
        for value_col in pvals:
            if value_col not in col_to_aggs:
                col_to_aggs[value_col] = ('avg', )
        return col_to_aggs

    # Forms 1-3: the same aggregate list applies to every value column
    shared_aggs = tuple(split_quoted_delimited_str(aggregate_func))
    if not shared_aggs:
        shared_aggs = ('avg', )
    return dict.fromkeys(pvals, shared_aggs)
| # ------------------------------------------------------------------------------ |
| |
| |
def _has_valid_column_names(cols):
    """Return True if cols is non-empty and no entry is blank or 'null'.

    Accepts either a single string or a list of column names, so a literal
    'null'/'' argument is rejected in both forms. Quoted names (e.g. '"null"')
    keep their quotes at this point and are therefore not rejected.
    """
    if not cols:
        return False
    names = [cols] if hasattr(cols, 'strip') else cols
    return all(name is not None and
               name.strip().lower() not in ('null', '')
               for name in names)


def validate_pivot_coding(source_table, output_table, indices, pivs, vals):
    """
    Validate the table and column arguments of the pivot function.

    Args:
        @param source_table The original data table
        @param output_table The output table that will contain pivoted columns
        @param indices      An array of index column names
        @param pivs         An array of pivot column names
        @param vals         An array of value column names

    Returns:
        None. Each failed check is reported through _assert.
    """
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Pivot: Invalid output table name!")
    _assert(not table_exists(output_table),
            "Pivot: Output table already exists!")
    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Pivot: Invalid data table name!")
    _assert(table_exists(source_table),
            "Pivot: Data table ({0}) is missing!". format(source_table))
    _assert(not table_is_empty(source_table),
            "Pivot: Data table ({0}) is empty!". format(source_table))

    # The column arguments arrive as lists, so the 'null'/'' check has to be
    # applied to each element rather than to the list object itself.
    _assert(_has_valid_column_names(indices), "Pivot: Invalid index column!")
    _assert(_has_valid_column_names(pivs), "Pivot: Invalid pivot column!")
    _assert(_has_valid_column_names(vals), "Pivot: Invalid value column!")

    _assert(columns_exist_in_table(source_table, indices),
            "Pivot: Not all columns from {0} present in source table ({1})"
            .format(indices, source_table))
    _assert(columns_exist_in_table(source_table, pivs),
            "Pivot: Not all columns from {0} present in source table ({1})"
            .format(pivs, source_table))
    _assert(columns_exist_in_table(source_table, vals),
            "Pivot: Not all columns from {0} present in source table ({1})"
            .format(vals, source_table))
| # ------------------------------------------------------------------------------ |
| |
| |
def validate_output_types(source_table, agg_dict, is_array_output):
    """
    Reject aggregates that return an array when the pivot output is itself
    an array (or svec).

    Args:
        @param source_table: str, Name of table containing data
        @param agg_dict: dict, Key-value pair containing aggregates applied
                         for each val column
        @param is_array_output: bool, Is the pivot output columnar (False)
                                or array (True)

    Returns:
        None
    """
    for value_col, aggregates in agg_dict.items():
        for aggregate in aggregates:
            # The expression type is looked up for every aggregate,
            # regardless of the requested output type.
            call_expr = '{0}({1})'.format(aggregate, value_col)
            returns_array = '[]' in get_expr_type(call_expr, source_table)
            is_allowed = not (returns_array and is_array_output)
            _assert(is_allowed,
                    "Pivot: Aggregate {0} with an array return type cannot be "
                    "combined with output_type='array' or 'svec'".format(aggregate))
| # ---------------------------------------------------------------------- |
| |
| |
def pivot_help(schema_madlib, message, **kwargs):
    """
    Help function for pivot

    Args:
        @param schema_madlib: string, Name of the MADlib schema, substituted
                              into the example calls
        @param message: string, Help message string. Empty/None returns the
                        summary; 'usage', 'help' or '?' (case-insensitive)
                        returns the usage text
        @param kwargs: Unused

    Returns:
        String. Help/usage information
    """
    option = message.strip().lower() if message else ''
    if not option:
        help_string = """
-----------------------------------------------------------------------
                            SUMMARY
-----------------------------------------------------------------------
Provide a data summarization tool that can do basic OLAP type operations on
data stored in one table and output the summarized data to a second table.
Typical operations are count, average, min, max and standard deviation, however
user defined aggregates (UDAs) are also allowed.

For more details on function usage:
    SELECT {schema_madlib}.pivot('usage')
"""
    elif option in ('usage', 'help', '?'):
        help_string = """
-----------------------------------------------------------------------
                            USAGE
-----------------------------------------------------------------------
 SELECT {schema_madlib}.pivot(
    source_table,           -- Name of source table containing data for pivoting
    output_table,           -- Name of output table that contains pivoted data
    index,                  -- Comma-separated columns that will form the index
                            --   of the output pivot table
    pivot_cols,             -- Comma-separated columns that will form the
                            --   columns of the output pivot table
    pivot_values,           -- Comma-separated columns that contain the values
                            --   to be summarized in the output pivot table
    aggregate_func,         -- Comma-separated list of aggregates applied to
                            --   the values, or a per-column mapping such as
                            --   'val=sum, val2=[avg,sum]' (default: avg)
    fill_value,             -- If specified, determines how to fill NULL values
                            --   resulting from pivot operation
    keep_null,              -- The flag for determining how to handle NULL
                            --   values in pivot columns
    output_col_dictionary,  -- The flag for enabling the creation of the
                            --   output dictionary for shorter column names
    output_type             -- This parameter controls the output format
                            --   of the pivoted variables.
                            --   If 'column', a column is created for each pivot
                            --   If 'array', an array is created combining all pivots
                            --   If 'svec', the array is cast to madlib.svec
 );

-----------------------------------------------------------------------
                            OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) has all the columns present in index
column list, plus additional columns for each distinct value in pivot_cols.
The column name for the pivot is set as
'<value column>_<aggregate>_<pivot column>_<pivot value>'.
For 'array' and 'svec' output there is one column per value column and
aggregate, named '<value column>_<aggregate>'.

A dictionary table ('<output_table>_dictionary') is created if either
'output_col_dictionary' is True or if the auto-generated column names exceed
the PostgreSQL limit of 63 bytes. This table gives a mapping between short
column names in 'output_table' and the meaning of those columns
i.e. which index, value, agg and pivot column they belong to.
"""
    else:
        help_string = "No such option. Use {schema_madlib}.pivot()"

    return help_string.format(schema_madlib=schema_madlib)
| # --------------------------------------------------------------------- |