| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| Functions to reproduce the post-processing of data on text charts. |
| |
Some text-based charts (pivot tables and the t-test table) perform
post-processing of the data in JavaScript. When sending the data
to users in reports we want to show the same data they would see
in Explore.
| |
| In order to do that, we reproduce the post-processing in Python |
| for these chart types. |
| """ |
| |
| from io import StringIO |
| from typing import Any, Dict, List, Optional, Tuple |
| |
| import pandas as pd |
| |
| from superset.utils.core import ( |
| ChartDataResultFormat, |
| DTTM_ALIAS, |
| extract_dataframe_dtypes, |
| get_metric_name, |
| ) |
| |
| |
| def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: |
| """ |
    Build a sort key for a column label, used when combining metrics.
| |
| MultiIndex labels have the metric name as the last element in the |
| tuple. We want to sort these according to the list of passed metrics. |
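
    Illustrative example (hypothetical labels and metrics)::

        >>> get_column_key(("state", "SUM(num)"), ["SUM(num)", "MAX(num)"])
        ('state', 0)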
| """ |
| parts: List[Any] = list(label) |
| metric = parts[-1] |
| parts[-1] = metrics.index(metric) |
| return tuple(parts) |
| |
| |
| def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches |
| df: pd.DataFrame, |
| rows: List[str], |
| columns: List[str], |
| metrics: List[str], |
| aggfunc: str = "Sum", |
| transpose_pivot: bool = False, |
| combine_metrics: bool = False, |
| show_rows_total: bool = False, |
| show_columns_total: bool = False, |
| apply_metrics_on_rows: bool = False, |
| ) -> pd.DataFrame: |
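    """
    Pivot a dataframe the way the Pivot Table charts do on the frontend.

    This mirrors the JavaScript post-processing (sorting, combined metrics,
    fraction aggregations, and row/column subtotals) so that the data sent
    in reports matches what users see in Explore.
    """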
| metric_name = f"Total ({aggfunc})" |
| |
| if transpose_pivot: |
| rows, columns = columns, rows |
| |
| # to apply the metrics on the rows we pivot the dataframe, apply the |
| # metrics to the columns, and pivot the dataframe back before |
| # returning it |
| if apply_metrics_on_rows: |
| rows, columns = columns, rows |
| axis = {"columns": 0, "rows": 1} |
| else: |
| axis = {"columns": 1, "rows": 0} |
| |
| # pivot data; we'll compute totals and subtotals later |
| if rows or columns: |
| df = df.pivot_table( |
| index=rows, |
| columns=columns, |
| values=metrics, |
| aggfunc=pivot_v2_aggfunc_map[aggfunc], |
| margins=False, |
| ) |
| else: |
        # if there are neither rows nor columns we have a single value;
        # update the index with the metric name so it shows up in the table
| df.index = pd.Index([*df.index[:-1], metric_name], name="metric") |
| |
| # if no rows were passed the metrics will be in the rows, so we |
| # need to move them back to columns |
| if columns and not rows: |
| df = df.stack().to_frame().T |
| df = df[metrics] |
| df.index = pd.Index([*df.index[:-1], metric_name], name="metric") |
| |
| # combining metrics changes the column hierarchy, moving the metric |
| # from the top to the bottom, eg: |
| # |
| # ('SUM(col)', 'age', 'name') => ('age', 'name', 'SUM(col)') |
| if combine_metrics and isinstance(df.columns, pd.MultiIndex): |
| # move metrics to the lowest level |
| new_order = [*range(1, df.columns.nlevels), 0] |
| df = df.reorder_levels(new_order, axis=1) |
| |
| # sort columns, combining metrics for each group |
| decorated_columns = [(col, i) for i, col in enumerate(df.columns)] |
| grouped_columns = sorted( |
| decorated_columns, key=lambda t: get_column_key(t[0], metrics) |
| ) |
| indexes = [i for col, i in grouped_columns] |
| df = df[df.columns[indexes]] |
| elif rows: |
| # if metrics were not combined we sort the dataframe by the list |
| # of metrics defined by the user |
| df = df[metrics] |
| |
| # compute fractions, if needed |
| if aggfunc.endswith(" as Fraction of Total"): |
| total = df.sum().sum() |
        # ``total`` is a scalar here, so the result is upcast automatically
        df = df / total
| elif aggfunc.endswith(" as Fraction of Columns"): |
| total = df.sum(axis=axis["rows"]) |
| df = df.astype(total.dtypes).div(total, axis=axis["columns"]) |
| elif aggfunc.endswith(" as Fraction of Rows"): |
| total = df.sum(axis=axis["columns"]) |
| df = df.astype(total.dtypes).div(total, axis=axis["rows"]) |
| |
| # convert to a MultiIndex to simplify logic |
| if not isinstance(df.index, pd.MultiIndex): |
| df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index]) |
| if not isinstance(df.columns, pd.MultiIndex): |
| df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns]) |
| |
| if show_rows_total: |
| # add subtotal for each group and overall total; we start from the |
| # overall group, and iterate deeper into subgroups |
| groups = df.columns |
| for level in range(df.columns.nlevels): |
| subgroups = {group[:level] for group in groups} |
| for subgroup in subgroups: |
| slice_ = df.columns.get_loc(subgroup) |
| subtotal = pivot_v2_aggfunc_map[aggfunc](df.iloc[:, slice_], axis=1) |
| depth = df.columns.nlevels - len(subgroup) - 1 |
| total = metric_name if level == 0 else "Subtotal" |
| subtotal_name = tuple([*subgroup, total, *([""] * depth)]) |
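                # e.g. with a two-level column index, the level-0 subgroup ()
                # yields (metric_name, "") and a level-1 subgroup such as
                # ("SUM(num)",) yields ("SUM(num)", "Subtotal")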
| # insert column after subgroup |
| df.insert(int(slice_.stop), subtotal_name, subtotal) |
| |
| if rows and show_columns_total: |
| # add subtotal for each group and overall total; we start from the |
| # overall group, and iterate deeper into subgroups |
| groups = df.index |
| for level in range(df.index.nlevels): |
| subgroups = {group[:level] for group in groups} |
| for subgroup in subgroups: |
| slice_ = df.index.get_loc(subgroup) |
| subtotal = pivot_v2_aggfunc_map[aggfunc]( |
| df.iloc[slice_, :].apply(pd.to_numeric), axis=0 |
| ) |
| depth = df.index.nlevels - len(subgroup) - 1 |
| total = metric_name if level == 0 else "Subtotal" |
| subtotal.name = tuple([*subgroup, total, *([""] * depth)]) |
| # insert row after subgroup |
| df = pd.concat( |
| [df[: slice_.stop], subtotal.to_frame().T, df[slice_.stop :]] |
| ) |
| |
| # if we want to apply the metrics on the rows we need to pivot the |
| # dataframe back |
| if apply_metrics_on_rows: |
| df = df.T |
| |
| return df |
| |
| |
| def list_unique_values(series: pd.Series) -> str: |
| """ |
    Return a comma-separated string of the unique values in a series.
| """ |
| return ", ".join(set(str(v) for v in pd.Series.unique(series))) |
| |
| |
| pivot_v2_aggfunc_map = { |
| "Count": pd.Series.count, |
| "Count Unique Values": pd.Series.nunique, |
| "List Unique Values": list_unique_values, |
| "Sum": pd.Series.sum, |
| "Average": pd.Series.mean, |
| "Median": pd.Series.median, |
| "Sample Variance": lambda series: pd.series.var(series) if len(series) > 1 else 0, |
| "Sample Standard Deviation": ( |
| lambda series: pd.series.std(series) if len(series) > 1 else 0, |
| ), |
| "Minimum": pd.Series.min, |
| "Maximum": pd.Series.max, |
| "First": lambda series: series[:1], |
| "Last": lambda series: series[-1:], |
| "Sum as Fraction of Total": pd.Series.sum, |
| "Sum as Fraction of Rows": pd.Series.sum, |
| "Sum as Fraction of Columns": pd.Series.sum, |
| "Count as Fraction of Total": pd.Series.count, |
| "Count as Fraction of Rows": pd.Series.count, |
| "Count as Fraction of Columns": pd.Series.count, |
| } |
| |
| |
| def pivot_table_v2(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: |
| """ |
| Pivot table v2. |
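
    Reproduce in Python the pivoting the Pivot Table v2 chart performs in
    JavaScript, driven by the chart's ``form_data``.

    Illustrative ``form_data`` (hypothetical values; only the keys read
    below matter)::

        {
            "groupbyRows": ["state"],
            "groupbyColumns": ["gender"],
            "metrics": ["count"],
            "aggregateFunction": "Sum",
            "rowTotals": True,
            "colTotals": True,
        }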
| """ |
| if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df: |
| del df[DTTM_ALIAS] |
| |
| return pivot_df( |
| df, |
| rows=form_data.get("groupbyRows") or [], |
| columns=form_data.get("groupbyColumns") or [], |
| metrics=[get_metric_name(m) for m in form_data["metrics"]], |
| aggfunc=form_data.get("aggregateFunction", "Sum"), |
| transpose_pivot=bool(form_data.get("transposePivot")), |
| combine_metrics=bool(form_data.get("combineMetric")), |
| show_rows_total=bool(form_data.get("rowTotals")), |
| show_columns_total=bool(form_data.get("colTotals")), |
| apply_metrics_on_rows=form_data.get("metricsLayout") == "ROWS", |
| ) |
| |
| |
| def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: |
| """ |
| Pivot table (v1). |
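
    This mirrors ``pivot_table_v2`` for the legacy chart. Illustrative
    ``form_data`` (hypothetical values; only the keys read below matter)::

        {
            "groupby": ["state"],
            "columns": ["gender"],
            "metrics": ["count"],
            "pandas_aggfunc": "sum",
            "pivot_margins": True,
        }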
| """ |
| if form_data.get("granularity") == "all" and DTTM_ALIAS in df: |
| del df[DTTM_ALIAS] |
| |
| # v1 func names => v2 func names |
| func_map = { |
| "sum": "Sum", |
| "mean": "Average", |
| "min": "Minimum", |
| "max": "Maximum", |
| "std": "Sample Standard Deviation", |
| "var": "Sample Variance", |
| } |
| |
| return pivot_df( |
| df, |
| rows=form_data.get("groupby") or [], |
| columns=form_data.get("columns") or [], |
| metrics=[get_metric_name(m) for m in form_data["metrics"]], |
| aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"), |
| transpose_pivot=bool(form_data.get("transpose_pivot")), |
| combine_metrics=bool(form_data.get("combine_metric")), |
| show_rows_total=bool(form_data.get("pivot_margins")), |
| show_columns_total=bool(form_data.get("pivot_margins")), |
| apply_metrics_on_rows=False, |
| ) |
| |
| |
| post_processors = { |
| "pivot_table": pivot_table, |
| "pivot_table_v2": pivot_table_v2, |
| } |
| |
| |
| def apply_post_process( |
| result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, |
| ) -> Dict[Any, Any]: |
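    """
    Apply the viz-type-specific post-processor to every query in a chart
    data response.

    If the viz type has no registered post-processor, ``result`` is
    returned unchanged. A minimal (hypothetical) payload shape::

        {
            "queries": [
                {"result_format": ChartDataResultFormat.JSON, "data": {...}},
            ],
        }
    """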
| form_data = form_data or {} |
| |
| viz_type = form_data.get("viz_type") |
| if viz_type not in post_processors: |
| return result |
| |
| post_processor = post_processors[viz_type] |
| |
| for query in result["queries"]: |
| if query["result_format"] == ChartDataResultFormat.JSON: |
| df = pd.DataFrame.from_dict(query["data"]) |
| elif query["result_format"] == ChartDataResultFormat.CSV: |
| df = pd.read_csv(StringIO(query["data"])) |
| else: |
| raise Exception(f"Result format {query['result_format']} not supported") |
| |
| processed_df = post_processor(df, form_data) |
| |
| query["colnames"] = list(processed_df.columns) |
| query["indexnames"] = list(processed_df.index) |
| query["coltypes"] = extract_dataframe_dtypes(processed_df) |
| query["rowcount"] = len(processed_df.index) |
| |
| # Flatten hierarchical columns/index since they are represented as |
| # `Tuple[str]`. Otherwise encoding to JSON later will fail because |
| # maps cannot have tuples as their keys in JSON. |
| processed_df.columns = [ |
| " ".join(str(name) for name in column).strip() |
| if isinstance(column, tuple) |
| else column |
| for column in processed_df.columns |
| ] |
| processed_df.index = [ |
| " ".join(str(name) for name in index).strip() |
| if isinstance(index, tuple) |
| else index |
| for index in processed_df.index |
| ] |
| |
| if query["result_format"] == ChartDataResultFormat.JSON: |
| query["data"] = processed_df.to_dict() |
| elif query["result_format"] == ChartDataResultFormat.CSV: |
| buf = StringIO() |
| processed_df.to_csv(buf) |
| buf.seek(0) |
| query["data"] = buf.getvalue() |
| |
| return result |