#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
from typing import TYPE_CHECKING, Any, List, Optional, Union

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.plot import (
PySparkPlotAccessor,
PySparkBoxPlotBase,
PySparkKdePlotBase,
PySparkHistogramPlotBase,
)
from pyspark.sql.types import NumericType

if TYPE_CHECKING:
from pyspark.sql import DataFrame
from plotly.graph_objs import Figure


def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
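    """Dispatch ``data`` to the plotly renderer for the given plot ``kind``."""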
    import plotly

if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "box":
return plot_box(data, **kwargs)
if kind == "kde" or kind == "density":
return plot_kde(data, **kwargs)
if kind == "hist":
return plot_histogram(data, **kwargs)
if kind not in PySparkPlotAccessor.plot_data_map:
raise PySparkValueError(
errorClass="UNSUPPORTED_PLOT_KIND",
messageParameters={
"plot_type": kind,
"supported_plot_types": ", ".join(
sorted(
list(PySparkPlotAccessor.plot_data_map.keys())
+ ["pie", "box", "kde", "density", "hist"]
)
),
},
)
return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
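

# Illustrative usage (a sketch, not part of this module): plot_pyspark is normally
# reached through the DataFrame plot accessor, but it can also be called directly.
# The SparkSession `spark` and the sample data are assumptions of this example:
#
#     df = spark.createDataFrame([(1, 10.0), (2, 20.0), (3, 15.0)], ["x", "y"])
#     fig = plot_pyspark(df, "line", x="x", y="y")  # dispatches via plot_data_map
#     fig.show()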


def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
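    """Render a pie chart; with ``subplots=True``, draw one pie per numeric column."""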
    from plotly import express

pdf = PySparkPlotAccessor.plot_data_map["pie"](data)
x = kwargs.pop("x", None)
y = kwargs.pop("y", None)
subplots = kwargs.pop("subplots", False)
if y is None and not subplots:
raise PySparkValueError(errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={})
numeric_ys = process_column_param(y, data)
if subplots:
# One pie chart per numeric column
from plotly.subplots import make_subplots
fig = make_subplots(
rows=1,
cols=len(numeric_ys),
            # Pie charts render as domain-based traces, so each subplot cell needs type "domain"
specs=[[{"type": "domain"}] * len(numeric_ys)],
)
for i, y_col in enumerate(numeric_ys):
subplot_fig = express.pie(pdf, values=y_col, names=x, **kwargs)
fig.add_trace(
subplot_fig.data[0], row=1, col=i + 1
) # A single pie chart has only one trace
else:
fig = express.pie(pdf, values=numeric_ys[0], names=x, **kwargs)
return fig
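

# Illustrative usage (a sketch, not part of this module): a single pie requires `y`,
# while subplots=True draws one pie per numeric column instead. `spark` and the
# sample data are assumptions of this example:
#
#     df = spark.createDataFrame(
#         [("sales", 10.0, 4.0), ("ops", 6.0, 8.0)], ["dept", "cost", "headcount"]
#     )
#     fig = plot_pie(df, x="dept", y="cost")       # one pie from one column
#     fig = plot_pie(df, x="dept", subplots=True)  # one pie per numeric column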


def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
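    """Render a box plot from Spark-side statistics computed with approx_percentile."""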
    import plotly.graph_objs as go

    # 'whis' is a matplotlib argument rather than a plotly one. Plotly does not seem
    # to expose how far the whiskers reach beyond the first and third quartiles, and
    # appears to default to 1.5, so we use the same default here.
whis = kwargs.pop("whis", 1.5)
# 'precision' is pyspark specific to control precision for approx_percentile
precision = kwargs.pop("precision", 0.01)
colnames = process_column_param(kwargs.pop("column", None), data)
# Plotly options
boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
notched = kwargs.pop("notched", False)
if boxpoints not in ["suspectedoutliers", False]:
raise PySparkValueError(
errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
messageParameters={
"backend": "plotly",
"param": "boxpoints",
"value": str(boxpoints),
"supported_values": ", ".join(["suspectedoutliers", "False"]),
},
)
if notched:
raise PySparkValueError(
errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
messageParameters={
"backend": "plotly",
"param": "notched",
"value": str(notched),
"supported_values": ", ".join(["False"]),
},
)
fig = go.Figure()
results = PySparkBoxPlotBase.compute_box(
data,
colnames,
whis,
precision,
        boxpoints is not None,  # whether to compute fliers (outlier points)
)
assert len(results) == len(colnames) # type: ignore
for i, colname in enumerate(colnames):
result = results[i] # type: ignore
fig.add_trace(
go.Box(
x=[i],
name=colname,
q1=[result["q1"]],
median=[result["med"]],
q3=[result["q3"]],
mean=[result["mean"]],
lowerfence=[result["lower_whisker"]],
upperfence=[result["upper_whisker"]],
y=[result["fliers"]] if result["fliers"] else None,
boxpoints=boxpoints,
notched=notched,
**kwargs,
)
)
fig["layout"]["yaxis"]["title"] = "value"
return fig
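

# Illustrative usage (a sketch, not part of this module): `precision` tunes the
# Spark-side approx_percentile and `whis` sets the whisker reach. `spark` and the
# sample data are assumptions of this example:
#
#     df = spark.createDataFrame([(1.0, 2.0), (5.0, 4.0), (9.0, 6.0)], ["a", "b"])
#     fig = plot_box(df, column=["a", "b"], precision=0.01, whis=1.5)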


def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
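    """Render a KDE (density) line plot; the densities are computed Spark-side."""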
    from pyspark.testing.utils import have_numpy
    from pyspark.sql.pandas.utils import require_minimum_pandas_version

    require_minimum_pandas_version()

    import pandas as pd
    from plotly import express

if "color" not in kwargs:
kwargs["color"] = "names"
bw_method = kwargs.pop("bw_method", None)
colnames = process_column_param(kwargs.pop("column", None), data)
ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None))
if have_numpy:
import numpy as np
if isinstance(ind, np.ndarray):
ind = [float(i) for i in ind]
kde_cols = [
PySparkKdePlotBase.compute_kde_col(
input_col=data[col_name],
ind=ind,
bw_method=bw_method,
).alias(f"kde_{i}")
for i, col_name in enumerate(colnames)
]
kde_results = data.select(*kde_cols).first()
pdf = pd.concat(
[
pd.DataFrame( # type: ignore
{
"Density": kde_result,
"names": col_name,
"index": ind,
}
)
for col_name, kde_result in zip(colnames, list(kde_results)) # type: ignore[arg-type]
]
)
fig = express.line(pdf, x="index", y="Density", **kwargs)
fig["layout"]["xaxis"]["title"] = None
return fig
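

# Illustrative usage (a sketch, not part of this module): `bw_method` sets the
# kernel bandwidth and `ind` the evaluation points (per PySparkKdePlotBase.get_ind,
# an integer appears to request that many points). `spark` and the sample data are
# assumptions of this example:
#
#     df = spark.createDataFrame([(1.0,), (2.0,), (2.5,), (4.0,)], ["v"])
#     fig = plot_kde(df, column="v", bw_method=0.3, ind=50)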


def plot_histogram(data: "DataFrame", **kwargs: Any) -> "Figure":
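    """Render a stacked histogram; bucket edges and counts are computed Spark-side."""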
    import plotly.graph_objs as go

bins = kwargs.get("bins", 10)
colnames = process_column_param(kwargs.pop("column", None), data)
numeric_data = data.select(*colnames)
bins = PySparkHistogramPlotBase.get_bins(numeric_data, bins)
    assert len(bins) > 2, "the number of bin edges must be greater than 2."
output_series = PySparkHistogramPlotBase.compute_hist(numeric_data, bins)
    prev = float("%.9f" % bins[0])  # truncate for prettier bucket labels
text_bins = []
for b in bins[1:]:
norm_b = float("%.9f" % b)
text_bins.append("[%s, %s)" % (prev, norm_b))
prev = norm_b
    text_bins[-1] = text_bins[-1][:-1] + "]"  # close the last bucket with "]" instead of ")"
    # Replace the bin edges with bin centers to position the bars.
    bins = [(bins[i] + bins[i + 1]) / 2 for i in range(0, len(bins) - 1)]
output_series = list(output_series)
bars = []
for series in output_series:
bars.append(
go.Bar(
x=bins,
y=series,
name=series.name,
text=text_bins,
hovertemplate=("variable=" + str(series.name) + "<br>value=%{text}<br>count=%{y}"),
)
)
    # Forward only the kwargs that plotly's go.Layout actually accepts.
    layout_keys = inspect.signature(go.Layout).parameters.keys()
    layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys}
fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs))
fig["layout"]["barmode"] = "stack"
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
return fig
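

# Illustrative usage (a sketch, not part of this module): `bins` may be an integer
# bucket count; the bucket edges and per-bucket counts come from
# PySparkHistogramPlotBase on the Spark side. `spark` and the sample data are
# assumptions of this example:
#
#     df = spark.createDataFrame([(1.0,), (2.0,), (2.2,), (9.0,)], ["v"])
#     fig = plot_histogram(df, column="v", bins=5)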


def process_column_param(column: Optional[Union[str, List[str]]], data: "DataFrame") -> List[str]:
"""
Processes the provided column parameter for a DataFrame.
    - If `column` is None, returns a list of all numeric columns in the DataFrame.
    - If `column` is a string, wraps it in a single-element list first.
    - If `column` is a list, checks that every specified column exists in the DataFrame
      and is of NumericType.
    - Raises a PySparkTypeError if any requested column is missing from the DataFrame
      or is not of NumericType.
"""
fields_by_name = {f.name: f for f in data.schema.fields}
if column is None:
return [name for name, f in fields_by_name.items() if isinstance(f.dataType, NumericType)]
if isinstance(column, str):
column = [column]
for col in column:
field = fields_by_name.get(col)
if not field or not isinstance(field.dataType, NumericType):
raise PySparkTypeError(
errorClass="PLOT_INVALID_TYPE_COLUMN",
messageParameters={
"col_name": col,
"valid_types": NumericType.__name__,
"col_type": field.dataType.__class__.__name__ if field else "None",
},
)
return column
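

# Illustrative behavior (a sketch, not part of this module): for a DataFrame with
# columns `name: string` and `age: int`, process_column_param(None, df) returns
# ["age"] (all numeric columns), process_column_param("age", df) returns ["age"],
# and process_column_param("name", df) raises PySparkTypeError because `name` is
# not of NumericType.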