blob: 0a3a0101e18982f0bc3c43e41eb23457c43bd441 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, TYPE_CHECKING, Optional, Union
from types import ModuleType
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.utils import require_minimum_plotly_version
if TYPE_CHECKING:
from pyspark.sql import DataFrame
import pandas as pd
from plotly.graph_objs import Figure
class PySparkTopNPlotBase:
    """Collects a row-capped slice of a Spark DataFrame into pandas for plotting."""

    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
        """
        Collect at most ``max_rows`` rows of ``sdf`` (as returned by ``sdf.limit``) into pandas.

        The cap is read from the ``spark.sql.pyspark.plotting.max_rows`` conf of the active
        SparkSession. As a side effect, ``self.partial`` is set to True when ``sdf`` had more
        rows than the cap (and the result was truncated), False otherwise.

        Parameters
        ----------
        sdf : DataFrame
            The Spark DataFrame to collect.

        Returns
        -------
        pandas.DataFrame
            At most ``max_rows`` rows of ``sdf``.

        Raises
        ------
        PySparkRuntimeError
            If there is no active SparkSession.
        """
        from pyspark.sql import SparkSession

        active_session = SparkSession.getActiveSession()
        if active_session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        row_cap = int(
            active_session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )
        # Pull one surplus row: its presence reveals that the data exceeded the cap
        # without having to run a separate count() job.
        collected = sdf.limit(row_cap + 1).toPandas()
        self.partial = len(collected) > row_cap
        if not self.partial:
            return collected
        return collected.iloc[:row_cap]
class PySparkSampledPlotBase:
    """Randomly samples a Spark DataFrame down to the plotting row cap and collects it."""

    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
        """
        Collect a random sample of at most ``max_rows`` rows of ``sdf`` into pandas.

        The cap is read from the ``spark.sql.pyspark.plotting.max_rows`` conf of the active
        SparkSession. Sampled rows keep their original relative order. As a side effect,
        ``self.fraction`` is always set: 1.0 when every row fit under the cap, otherwise the
        share of rows kept (see ``_sampling_fraction``).

        Parameters
        ----------
        sdf : DataFrame
            The Spark DataFrame to sample.

        Returns
        -------
        pandas.DataFrame
            At most ``max_rows`` rows of ``sdf``.

        Raises
        ------
        PySparkRuntimeError
            If there is no active SparkSession.
        """
        from pyspark.sql import SparkSession, Observation, functions as F

        session = SparkSession.getActiveSession()
        if session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        max_rows = int(
            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )

        # Counts the full input while the sampling query runs, avoiding a second count() job.
        observation = Observation("pyspark plotting")

        rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"
        id_col_name = "__pyspark_plotting_sampled_plot_base_id__"

        sampled_sdf = (
            sdf.observe(observation, F.count(F.lit(1)).alias("count"))
            .select(
                "*",
                F.rand().alias(rand_col_name),
                F.monotonically_increasing_id().alias(id_col_name),
            )
            # Random order + limit picks a random subset; the one extra row signals that the
            # input was larger than the cap.
            .sort(rand_col_name)
            .limit(max_rows + 1)
            # Put the sample back into its original relative order, then drop helper columns.
            .coalesce(1)
            .sortWithinPartitions(id_col_name)
            .drop(rand_col_name, id_col_name)
        )
        pdf = sampled_sdf.toPandas()

        if len(pdf) <= max_rows:
            # Everything fit under the cap: nothing was sampled away.
            self.fraction = 1.0
            return pdf

        try:
            total_count = observation.get["count"]
        except Exception:
            # Best effort: if the observed metric is unavailable, fall back to an estimate
            # instead of failing the plot.
            total_count = None
        # Always assign, so a value left over from an earlier call (this instance may be
        # shared) is never reported for the current DataFrame.
        self.fraction = self._sampling_fraction(max_rows, total_count)
        return pdf.iloc[:max_rows]

    @staticmethod
    def _sampling_fraction(max_rows: int, total_count: Optional[int]) -> float:
        """
        Return the share of rows kept when ``max_rows`` rows were sampled from ``total_count``.

        Only meaningful after more than ``max_rows`` rows were collected, so the true total is
        at least ``max_rows + 1``. When ``total_count`` is unknown (``None``) or not positive,
        return ``max_rows / (max_rows + 1)``: an upper bound on the real fraction that is still
        below 1.0, i.e. it still marks the result as sampled.
        """
        if total_count is None or total_count <= 0:
            return float(max_rows) / (max_rows + 1)
        return float(max_rows) / total_count
class PySparkPlotAccessor:
    """
    Plotting accessor for a PySpark DataFrame, used as ``df.plot`` (see the method examples).

    ``plot_data_map`` maps each plot kind to its data-collection routine: the top-N rows for
    bar-type plots and a random sample for line/scatter plots.
    NOTE(review): those routines are bound to instances created once at class definition, so
    the ``partial``/``fraction`` attributes they set are effectively process-wide state.
    """

    plot_data_map = {
        "bar": PySparkTopNPlotBase().get_top_n,
        "barh": PySparkTopNPlotBase().get_top_n,
        "line": PySparkSampledPlotBase().get_sampled,
        "scatter": PySparkSampledPlotBase().get_sampled,
    }
    # Cache of backend name -> loaded backend module, shared by all accessor instances.
    _backends: dict[str, ModuleType] = {}

    def __init__(self, data: "DataFrame") -> None:
        self.data = data

    def __call__(
        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
    ) -> "Figure":
        """
        Plot the DataFrame as the given ``kind`` using the selected backend.

        Parameters
        ----------
        kind : str, default "line"
            Kind of plot, e.g. "line", "bar", "barh" or "scatter".
        backend : str, optional
            Name of the plotting backend. Defaults to "plotly", currently the only one supported.
        **kwargs : optional
            Forwarded to the backend's ``plot_pyspark``.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`
        """
        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)
        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)

    @staticmethod
    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
        """
        Resolve and return the plotting backend module named ``backend`` ("plotly" if None).

        The module is validated and imported on first use, then served from ``_backends`` on
        later calls.

        Raises
        ------
        PySparkValueError
            If ``backend`` is not a supported backend name.
        """
        backend = backend or "plotly"

        if backend in PySparkPlotAccessor._backends:
            return PySparkPlotAccessor._backends[backend]

        if backend == "plotly":
            # Presumably raises if plotly is missing or too old; only cache after it passes.
            require_minimum_plotly_version()
        else:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_PLOT_BACKEND",
                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
            )
        from pyspark.sql.plot import plotly as module

        # Store the module so the version check and import are not repeated on every plot.
        PySparkPlotAccessor._backends[backend] = module
        return module

    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="line", x=x, y=y, **kwargs)

    def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Vertical bar plot.

        A bar plot is a plot that presents categorical data with rectangular bars with lengths
        proportional to the values that they represent. A bar plot shows comparisons among
        discrete categories. One axis of the plot shows the specific categories being compared,
        and the other axis represents a measured value.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.bar(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.bar(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="bar", x=x, y=y, **kwargs)

    def barh(
        self, x: Union[str, list[str]], y: Union[str, list[str]], **kwargs: Any
    ) -> "Figure":
        """
        Make a horizontal bar plot.

        A horizontal bar plot is a plot that presents quantitative data with
        rectangular bars with lengths proportional to the values that they
        represent. A bar plot shows comparisons among discrete categories. One
        axis of the plot shows the specific categories being compared, and the
        other axis represents a measured value.

        Parameters
        ----------
        x : str or list of str
            Name(s) of the column(s) to use for the horizontal axis.
            Multiple columns can be plotted.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Notes
        -----
        In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs.
        In Plotly, `x` refers to the values and `y` refers to the categories.
        In Matplotlib, `x` refers to the categories and `y` refers to the values.
        Ensure correct axis labeling based on the backend used.

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.barh(x="int_val", y="category")  # doctest: +SKIP
        >>> df.plot.barh(
        ...     x=["int_val", "float_val"], y="category"
        ... )  # doctest: +SKIP
        """
        return self(kind="barh", x=x, y=y, **kwargs)

    def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure":
        """
        Create a scatter plot with varying marker point size and color.

        The coordinates of each point are defined by two dataframe columns and
        filled circles are used to represent each point. This kind of plot is
        useful to see complex correlations between two variables. Points could
        be for instance natural 2D coordinates like longitude and latitude in
        a map or, in general, any pair of metrics that can be plotted against
        each other.

        Parameters
        ----------
        x : str
            Name of column to use as horizontal coordinates for each point.
        y : str
            Name of column to use as vertical coordinates for each point.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ['length', 'width', 'species']
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.scatter(x='length', y='width')  # doctest: +SKIP
        """
        return self(kind="scatter", x=x, y=y, **kwargs)