[SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
### What changes were proposed in this pull request?
Support line plot with plotly backend on both Spark Connect and Spark classic.
### Why are the changes needed?
While the pandas API on Spark supports plotting, PySpark itself currently lacks this feature. The proposed API enables users to generate visualizations, such as line plots, by leveraging libraries like Plotly, giving them an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames and streamlining the data analysis workflow in distributed environments.
See more details in the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing).
Part of https://issues.apache.org/jira/browse/SPARK-49530.
### Does this PR introduce _any_ user-facing change?
Yes.
```python
>>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
>>> columns = ["category", "int_val", "float_val"]
>>> sdf = spark.createDataFrame(data, columns)
>>> sdf.show()
+--------+-------+---------+
|category|int_val|float_val|
+--------+-------+---------+
| A| 10| 1.5|
| B| 30| 2.5|
| C| 20| 3.5|
+--------+-------+---------+
>>> f = sdf.plot(kind="line", x="category", y="int_val")
>>> f.show() # see below
>>> g = sdf.plot.line(x="category", y=["int_val", "float_val"])
>>> g.show() # see below
```
`f.show()` renders a single-line chart of `int_val` over `category`; `g.show()` renders two lines, `int_val` and `float_val`, over `category` (figures omitted).
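The amount of data rendered is controlled server-side by the new SQL confs introduced below: `spark.sql.pyspark.plotting.max_rows` (default 1000) and the optional `spark.sql.pyspark.plotting.sample_ratio`. A minimal sketch of tuning them, reusing `sdf` from above:
```python
>>> # Sample roughly half of the rows instead of the max_rows-derived fraction.
>>> spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")
>>> h = sdf.plot.line(x="category", y="int_val")
>>> h.show()
>>> spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")
```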
### How was this patch tested?
Unit tests.
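For reference, a sketch of running the new suites directly with `unittest` (assuming plotly and a local PySpark build are importable; module names as registered in `dev/sparktestsupport/modules.py` below):
```python
import unittest

# Load the new classic plotting suites by their fully qualified module names.
suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "pyspark.sql.tests.plot.test_frame_plot",
        "pyspark.sql.tests.plot.test_frame_plot_plotly",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)
```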
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48139 from xinrong-meng/plot_line_w_dep.
Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index 3ac1a01..f668d81 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -71,7 +71,7 @@
python packaging/connect/setup.py sdist
cd dist
pip install pyspark*connect-*.tar.gz
- pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting
+ pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting
- name: Run tests
env:
SPARK_TESTING: 1
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 5486c98..cafc734 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -7,7 +7,7 @@
six==1.16.0
pandas>=2.0.0
scipy
-plotly
+plotly>=4.8
mlflow>=2.3.1
scikit-learn
matplotlib
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 34fbb84..b9a4bed 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -548,6 +548,8 @@
"pyspark.sql.tests.test_udtf",
"pyspark.sql.tests.test_utils",
"pyspark.sql.tests.test_resources",
+ "pyspark.sql.tests.plot.test_frame_plot",
+ "pyspark.sql.tests.plot.test_frame_plot_plotly",
],
)
@@ -1051,6 +1053,8 @@
"pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
"pyspark.sql.tests.connect.test_parity_python_datasource",
"pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
+ "pyspark.sql.tests.connect.test_parity_frame_plot",
+ "pyspark.sql.tests.connect.test_parity_frame_plot_plotly",
"pyspark.sql.tests.connect.test_utils",
"pyspark.sql.tests.connect.client.test_artifact",
"pyspark.sql.tests.connect.client.test_artifact_localcluster",
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 549656b..88c0a8c 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -183,6 +183,7 @@
Additional libraries that enhance functionality but are not included in the installation packages:
- **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``.
+- **plotly**: Used for PySpark plotting, ``DataFrame.plot``.
Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_.
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index 79b7448..17cca32 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -288,6 +288,7 @@
"pyspark.sql.connect.streaming.worker",
"pyspark.sql.functions",
"pyspark.sql.pandas",
+ "pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
"pyspark.sql.worker",
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index ab166c7..6ae16e9 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@
"pyspark.sql.tests.connect.client",
"pyspark.sql.tests.connect.shell",
"pyspark.sql.tests.pandas",
+ "pyspark.sql.tests.plot",
"pyspark.sql.tests.streaming",
"pyspark.ml.tests.connect",
"pyspark.pandas.tests",
@@ -161,6 +162,7 @@
"pyspark.sql.connect.streaming.worker",
"pyspark.sql.functions",
"pyspark.sql.pandas",
+ "pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
"pyspark.sql.worker",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 4061d02..92aeb15 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1088,6 +1088,11 @@
"Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
]
},
+ "UNSUPPORTED_PLOT_BACKEND": {
+ "message": [
+ "`<backend>` is not supported; it should be one of the values from <supported_backends>."
+ ]
+ },
"UNSUPPORTED_SIGNATURE": {
"message": [
"Unsupported signature: <signature>."
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 91b9591..a2778cb 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -73,6 +73,11 @@
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+try:
+ from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+ PySparkPlotAccessor = None # type: ignore
+
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
@@ -1862,6 +1867,10 @@
messageParameters={"member": "queryExecution"},
)
+ @property
+ def plot(self) -> PySparkPlotAccessor:
+ return PySparkPlotAccessor(self)
+
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 768abd6..59d79de 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -86,6 +86,10 @@
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]
+try:
+ from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+ PySparkPlotAccessor = None # type: ignore
if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
@@ -2239,6 +2243,10 @@
def executionInfo(self) -> Optional["ExecutionInfo"]:
return self._execution_info
+ @property
+ def plot(self) -> PySparkPlotAccessor:
+ return PySparkPlotAccessor(self)
+
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef35b73..2179a84 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -43,6 +43,7 @@
from pyspark.sql.types import StructType, Row
from pyspark.sql.utils import dispatch_df_method
+
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
@@ -65,6 +66,7 @@
ArrowMapIterFunction,
DataFrameLike as PandasDataFrameLike,
)
+ from pyspark.sql.plot import PySparkPlotAccessor
from pyspark.sql.metrics import ExecutionInfo
@@ -6394,6 +6396,32 @@
"""
...
+ @property
+ def plot(self) -> "PySparkPlotAccessor":
+ """
+ Returns a :class:`PySparkPlotAccessor` for plotting functions.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ :class:`PySparkPlotAccessor`
+
+ Notes
+ -----
+ This API is experimental.
+
+ Examples
+ --------
+ >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ >>> columns = ["category", "int_val", "float_val"]
+ >>> df = spark.createDataFrame(data, columns)
+ >>> type(df.plot)
+ <class 'pyspark.sql.plot.core.PySparkPlotAccessor'>
+ >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP
+ """
+ ...
+
class DataFrameNaFunctions:
"""Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py
new file mode 100644
index 0000000..6da0706
--- /dev/null
+++ b/python/pyspark/sql/plot/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+This package includes the plotting APIs for PySpark DataFrame.
+"""
+from pyspark.sql.plot.core import * # noqa: F403, F401
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
new file mode 100644
index 0000000..392ef73
--- /dev/null
+++ b/python/pyspark/sql/plot/core.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, TYPE_CHECKING, Optional, Union
+from types import ModuleType
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.sql.utils import require_minimum_plotly_version
+
+
+if TYPE_CHECKING:
+ from pyspark.sql import DataFrame
+ import pandas as pd
+ from plotly.graph_objs import Figure
+
+
+class PySparkTopNPlotBase:
+ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
+ from pyspark.sql import SparkSession
+
+ session = SparkSession.getActiveSession()
+ if session is None:
+ raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
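+ # Fetch one extra row so we can tell whether the data was truncated at max_rows.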
+ max_rows = int(
+ session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
+ )
+ pdf = sdf.limit(max_rows + 1).toPandas()
+
+ self.partial = False
+ if len(pdf) > max_rows:
+ self.partial = True
+ pdf = pdf.iloc[:max_rows]
+
+ return pdf
+
+
+class PySparkSampledPlotBase:
+ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
+ from pyspark.sql import SparkSession
+
+ session = SparkSession.getActiveSession()
+ if session is None:
+ raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+ sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
+ max_rows = int(
+ session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
+ )
+
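+ # Without an explicit ratio, derive one so that roughly max_rows rows are plotted.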
+ if sample_ratio is None:
+ fraction = 1 / (sdf.count() / max_rows)
+ fraction = min(1.0, fraction)
+ else:
+ fraction = float(sample_ratio)
+
+ sampled_sdf = sdf.sample(fraction=fraction)
+ pdf = sampled_sdf.toPandas()
+
+ return pdf
+
+
+class PySparkPlotAccessor:
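+ # Maps each plot kind to the strategy used to materialize the data locally as pandas.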
+ plot_data_map = {
+ "line": PySparkSampledPlotBase().get_sampled,
+ }
+ _backends = {} # type: ignore[var-annotated]
+
+ def __init__(self, data: "DataFrame"):
+ self.data = data
+
+ def __call__(
+ self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
+ ) -> "Figure":
+ plot_backend = PySparkPlotAccessor._get_plot_backend(backend)
+
+ return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)
+
+ @staticmethod
+ def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
+ backend = backend or "plotly"
+
+ if backend in PySparkPlotAccessor._backends:
+ return PySparkPlotAccessor._backends[backend]
+
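+ # Only the plotly backend is supported for now; its dependency is validated lazily.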
+ if backend == "plotly":
+ require_minimum_plotly_version()
+ else:
+ raise PySparkValueError(
+ errorClass="UNSUPPORTED_PLOT_BACKEND",
+ messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
+ )
+ from pyspark.sql.plot import plotly as module
+
+ return module
+
+ def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
+ """
+ Plot DataFrame as lines.
+
+ Parameters
+ ----------
+ x : str
+ Name of column to use for the horizontal axis.
+ y : str or list of str
+ Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
+ **kwargs : optional
+ Additional keyword arguments.
+
+ Returns
+ -------
+ :class:`plotly.graph_objs.Figure`
+
+ Examples
+ --------
+ >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ >>> columns = ["category", "int_val", "float_val"]
+ >>> df = spark.createDataFrame(data, columns)
+ >>> df.plot.line(x="category", y="int_val") # doctest: +SKIP
+ >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP
+ """
+ return self(kind="line", x=x, y=y, **kwargs)
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
new file mode 100644
index 0000000..5efc194
--- /dev/null
+++ b/python/pyspark/sql/plot/plotly.py
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import TYPE_CHECKING, Any
+
+from pyspark.sql.plot import PySparkPlotAccessor
+
+if TYPE_CHECKING:
+ from pyspark.sql import DataFrame
+ from plotly.graph_objs import Figure
+
+
+def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
+ import plotly
+
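+ # plotly>=4.8 registers a pandas plotting backend; plotly.plot is its entry point.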
+ return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
new file mode 100644
index 0000000..c69e438
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin
+
+
+class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_frame_plot import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
new file mode 100644
index 0000000..78508fe
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin
+
+
+class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py
new file mode 100644
index 0000000..f753b5a
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pyspark.errors import PySparkValueError
+from pyspark.sql import Row
+from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+
+
+@unittest.skipIf(not have_plotly, plotly_requirement_message)
+class DataFramePlotTestsMixin:
+ def test_backend(self):
+ accessor = self.spark.range(2).plot
+ backend = accessor._get_plot_backend()
+ self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly")
+
+ with self.assertRaises(PySparkValueError) as pe:
+ accessor._get_plot_backend("matplotlib")
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="UNSUPPORTED_PLOT_BACKEND",
+ messageParameters={"backend": "matplotlib", "supported_backends": "plotly"},
+ )
+
+ def test_topn_max_rows(self):
+ try:
+ self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
+ sdf = self.spark.range(2500)
+ pdf = PySparkTopNPlotBase().get_top_n(sdf)
+ self.assertEqual(len(pdf), 1000)
+ finally:
+ self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows")
+
+ def test_sampled_plot_with_ratio(self):
+ try:
+ self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")
+ data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)]
+ sdf = self.spark.createDataFrame(data)
+ pdf = PySparkSampledPlotBase().get_sampled(sdf)
+ self.assertEqual(round(len(pdf) / 2500, 1), 0.5)
+ finally:
+ self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")
+
+ def test_sampled_plot_with_max_rows(self):
+ data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)]
+ sdf = self.spark.createDataFrame(data)
+ pdf = PySparkSampledPlotBase().get_sampled(sdf)
+ self.assertEqual(round(len(pdf) / 2000, 1), 0.5)
+
+
+class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.plot.test_frame_plot import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
new file mode 100644
index 0000000..72a3ed2
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import pyspark.sql.plot # noqa: F401
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+
+
+@unittest.skipIf(not have_plotly, plotly_requirement_message)
+class DataFramePlotPlotlyTestsMixin:
+ @property
+ def sdf(self):
+ data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ columns = ["category", "int_val", "float_val"]
+ return self.spark.createDataFrame(data, columns)
+
+ def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""):
+ self.assertEqual(fig_data["mode"], "lines")
+ self.assertEqual(fig_data["type"], "scatter")
+ self.assertEqual(fig_data["xaxis"], "x")
+ self.assertEqual(list(fig_data["x"]), expected_x)
+ self.assertEqual(fig_data["yaxis"], "y")
+ self.assertEqual(list(fig_data["y"]), expected_y)
+ self.assertEqual(fig_data["name"], expected_name)
+
+ def test_line_plot(self):
+ # single column as vertical axis
+ fig = self.sdf.plot(kind="line", x="category", y="int_val")
+ self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20])
+
+ # multiple columns as vertical axis
+ fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
+ self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
+ self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
+
+
+class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.plot.test_frame_plot_plotly import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 11b9161..5d9ec92 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -41,6 +41,7 @@
PythonException,
UnknownException,
SparkUpgradeException,
+ PySparkImportError,
PySparkNotImplementedError,
PySparkRuntimeError,
)
@@ -115,6 +116,22 @@
)
+def require_minimum_plotly_version() -> None:
+ """Raise PySparkImportError if plotly is not installed."""
+ minimum_plotly_version = "4.8"
+
+ try:
+ import plotly # noqa: F401
+ except ImportError as error:
+ raise PySparkImportError(
+ errorClass="PACKAGE_NOT_INSTALLED",
+ messageParameters={
+ "package_name": "plotly",
+ "minimum_version": str(minimum_plotly_version),
+ },
+ ) from error
+
+
class ForeachBatchFunction:
"""
This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 9f07c44..00ad40e 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -48,6 +48,13 @@
except Exception as e:
test_not_compiled_message = str(e)
+plotly_requirement_message = None
+try:
+ import plotly  # noqa: F401
+except ImportError as e:
+ plotly_requirement_message = str(e)
+have_plotly = plotly_requirement_message is None
+
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2eaafde..6c3e9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3169,6 +3169,29 @@
.version("4.0.0")
.fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED)
+ val PYSPARK_PLOT_MAX_ROWS =
+ buildConf("spark.sql.pyspark.plotting.max_rows")
+ .doc(
+ "The visual limit on top-n-based plots. If set to 1000, the first 1000 data points " +
+ "will be used for plotting.")
+ .version("4.0.0")
+ .intConf
+ .createWithDefault(1000)
+
+ val PYSPARK_PLOT_SAMPLE_RATIO =
+ buildConf("spark.sql.pyspark.plotting.sample_ratio")
+ .doc(
+ "The proportion of data that will be plotted for sample-based plots. It is determined " +
+ "based on spark.sql.pyspark.plotting.max_rows if not explicitly set."
+ )
+ .version("4.0.0")
+ .doubleConf
+ .checkValue(
+ ratio => ratio >= 0.0 && ratio <= 1.0,
+ "The value should be between 0.0 and 1.0 inclusive."
+ )
+ .createOptional
+
val ARROW_SPARKR_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.sparkr.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " +
@@ -5873,6 +5896,10 @@
def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
+ def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)
+
+ def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO)
+
def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED)