blob: a23c255757ee371f33156e5e519fe129305f039d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
from typing import TYPE_CHECKING, Union
import pandas as pd
from pyspark.pandas.plot import (
HistogramPlotBase,
name_like_string,
PandasOnSparkPlotAccessor,
BoxPlotBase,
KdePlotBase,
)
if TYPE_CHECKING:
import pyspark.pandas as ps
def plot_pandas_on_spark(data: Union["ps.DataFrame", "ps.Series"], kind: str, **kwargs):
import plotly
# pandas-on-Spark specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
return plot_histogram(data, **kwargs)
if kind == "box":
return plot_box(data, **kwargs)
if kind == "kde" or kind == "density":
return plot_kde(data, **kwargs)
# Other plots.
return plotly.plot(PandasOnSparkPlotAccessor.pandas_plot_data_map[kind](data), kind, **kwargs)
def plot_pie(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
from plotly import express
data = PandasOnSparkPlotAccessor.pandas_plot_data_map["pie"](data)
if isinstance(data, pd.Series):
pdf = data.to_frame()
return express.pie(pdf, values=pdf.columns[0], names=pdf.index, **kwargs)
elif isinstance(data, pd.DataFrame):
values = kwargs.pop("y", None)
default_names = None
if values is not None:
default_names = data.index
return express.pie(
data,
values=kwargs.pop("values", values),
names=kwargs.pop("names", default_names),
**kwargs,
)
else:
raise RuntimeError("Unexpected type: [%s]" % type(data))
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
import plotly.graph_objs as go
import pyspark.pandas as ps
bins = kwargs.get("bins", 10)
y = kwargs.get("y")
if y and isinstance(data, ps.DataFrame):
# Note that the results here are matched with matplotlib. x and y
# handling is different from pandas' plotly output.
data = data[y]
psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
assert len(bins) > 2, "the number of buckets must be higher than 2."
output_series = HistogramPlotBase.compute_hist(psdf, bins)
prev = float("%.9f" % bins[0]) # to make it prettier, truncate.
text_bins = []
for b in bins[1:]:
norm_b = float("%.9f" % b)
text_bins.append("[%s, %s)" % (prev, norm_b))
prev = norm_b
text_bins[-1] = text_bins[-1][:-1] + "]" # replace ) to ] for the last bucket.
bins = 0.5 * (bins[:-1] + bins[1:])
output_series = list(output_series)
bars = []
for series in output_series:
bars.append(
go.Bar(
x=bins,
y=series,
name=name_like_string(series.name),
text=text_bins,
hovertemplate=(
"variable=" + name_like_string(series.name) + "<br>value=%{text}<br>count=%{y}"
),
)
)
layout_keys = inspect.signature(go.Layout).parameters.keys()
layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys}
fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs))
fig["layout"]["barmode"] = "stack"
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
return fig
def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
import plotly.graph_objs as go
import pyspark.pandas as ps
from pyspark.sql.types import NumericType
# 'whis' isn't actually an argument in plotly (but in matplotlib). But seems like
# plotly doesn't expose the reach of the whiskers to the beyond the first and
# third quartiles (?). Looks they use default 1.5.
whis = kwargs.pop("whis", 1.5)
# 'precision' is pandas-on-Spark specific to control precision for approx_percentile
precision = kwargs.pop("precision", 0.01)
# Plotly options
boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
notched = kwargs.pop("notched", False)
if boxpoints not in ["suspectedoutliers", False]:
raise ValueError(
"plotly plotting backend does not support 'boxpoints' set to '%s'. "
"Set to 'suspectedoutliers' or False." % boxpoints
)
if notched:
raise ValueError(
"plotly plotting backend does not support 'notched' set to '%s'. "
"Set to False." % notched
)
fig = go.Figure()
if isinstance(data, ps.Series):
sdf = data._psdf._internal.resolved_copy.spark_frame
spark_column_name = data._internal.spark_column_name_for(data._column_label)
colnames = [spark_column_name]
else:
sdf = data._internal.resolved_copy.spark_frame
colnames = []
for column_label in data._internal.column_labels:
if isinstance(data._internal.spark_type_for(column_label), NumericType):
colnames.append(name_like_string(column_label))
results = BoxPlotBase.compute_box(
sdf,
colnames,
whis,
precision,
boxpoints is not None,
)
assert len(results) == len(colnames)
if isinstance(data, ps.Series):
colname = name_like_string(data.name)
result = results[0]
fig.add_trace(
go.Box(
name=colname,
q1=[result["q1"]],
median=[result["med"]],
q3=[result["q3"]],
mean=[result["mean"]],
lowerfence=[result["lower_whisker"]],
upperfence=[result["upper_whisker"]],
y=[result["fliers"]] if result["fliers"] else None,
boxpoints=boxpoints,
notched=notched,
**kwargs, # this is for workarounds. Box takes different options from express.box.
)
)
fig["layout"]["xaxis"]["title"] = colname
else:
for i, colname in enumerate(colnames):
result = results[i]
fig.add_trace(
go.Box(
x=[i],
name=colname,
q1=[result["q1"]],
median=[result["med"]],
q3=[result["q3"]],
mean=[result["mean"]],
lowerfence=[result["lower_whisker"]],
upperfence=[result["upper_whisker"]],
y=[result["fliers"]] if result["fliers"] else None,
boxpoints=boxpoints,
notched=notched,
**kwargs,
)
)
fig["layout"]["yaxis"]["title"] = "value"
return fig
def plot_kde(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
from plotly import express
import pyspark.pandas as ps
if isinstance(data, ps.DataFrame) and "color" not in kwargs:
kwargs["color"] = "names"
psdf = KdePlotBase.prepare_kde_data(data)
sdf = psdf._internal.spark_frame
data_columns = psdf._internal.data_spark_columns
ind = KdePlotBase.get_ind(sdf.select(*data_columns), kwargs.pop("ind", None))
bw_method = kwargs.pop("bw_method", None)
kde_cols = [
KdePlotBase.compute_kde_col(
input_col=psdf._internal.spark_column_for(label),
ind=ind,
bw_method=bw_method,
).alias(f"kde_{i}")
for i, label in enumerate(psdf._internal.column_labels)
]
kde_results = sdf.select(*kde_cols).first()
pdf = pd.concat(
[
pd.DataFrame(
{
"Density": kde_result,
"names": name_like_string(label),
"index": ind,
}
)
for label, kde_result in zip(psdf._internal.column_labels, list(kde_results))
]
)
fig = express.line(pdf, x="index", y="Density", **kwargs)
fig["layout"]["xaxis"]["title"] = None
return fig