blob: ccfe1a75424e0f16d3762ebedf82abfb09658d6c [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import pyspark.sql.plot # noqa: F401
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
@unittest.skipUnless(have_plotly, plotly_requirement_message)
class DataFramePlotPlotlyTestsMixin:
    """Checks of the figures returned by ``DataFrame.plot`` for several plot kinds.

    The host test case must provide ``self.spark`` and the ``unittest``
    assertion methods. The whole class is skipped when ``have_plotly`` is falsy.
    """

    @property
    def sdf(self):
        """Build a three-row DataFrame with columns category, int_val, float_val."""
        rows = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        return self.spark.createDataFrame(rows, ["category", "int_val", "float_val"])

    @property
    def sdf2(self):
        """Build a five-row DataFrame with columns length, width, species."""
        rows = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        return self.spark.createDataFrame(rows, ["length", "width", "species"])

    def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""):
        """Assert that a single trace (``fig_data``) matches the expectations.

        ``kind`` selects the trace attributes that are specific to a plot type;
        a kind that is not in the table below adds no extra checks. The axis
        bindings, the x/y values and the trace name are verified for every kind.
        """
        kind_specific = {
            "line": (("mode", "lines"), ("type", "scatter")),
            "bar": (("type", "bar"),),
            "barh": (("type", "bar"), ("orientation", "h")),
            "scatter": (("type", "scatter"), ("orientation", "v")),
        }
        for attribute, wanted in kind_specific.get(kind, ()):
            self.assertEqual(fig_data[attribute], wanted)

        # Each axis: first its binding ("xaxis" -> "x"), then the plotted values.
        for axis, wanted_values in (("x", expected_x), ("y", expected_y)):
            self.assertEqual(fig_data[f"{axis}axis"], axis)
            self.assertEqual(list(fig_data[axis]), wanted_values)
        self.assertEqual(fig_data["name"], expected_name)

    def test_line_plot(self):
        categories = ["A", "B", "C"]

        # y given as a single column name
        single = self.sdf.plot(kind="line", x="category", y="int_val")
        self._check_fig_data("line", single["data"][0], categories, [10, 30, 20])

        # y given as a list of column names: one trace per column, in order
        multi = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
        per_trace = [("int_val", [10, 30, 20]), ("float_val", [1.5, 2.5, 3.5])]
        for position, (column, values) in enumerate(per_trace):
            self._check_fig_data("line", multi["data"][position], categories, values, column)

    def test_bar_plot(self):
        categories = ["A", "B", "C"]

        # y given as a single column name
        single = self.sdf.plot(kind="bar", x="category", y="int_val")
        self._check_fig_data("bar", single["data"][0], categories, [10, 30, 20])

        # y given as a list of column names: one trace per column, in order
        multi = self.sdf.plot.bar(x="category", y=["int_val", "float_val"])
        per_trace = [("int_val", [10, 30, 20]), ("float_val", [1.5, 2.5, 3.5])]
        for position, (column, values) in enumerate(per_trace):
            self._check_fig_data("bar", multi["data"][position], categories, values, column)

    def test_barh_plot(self):
        categories = ["A", "B", "C"]
        per_trace = [("int_val", [10, 30, 20]), ("float_val", [1.5, 2.5, 3.5])]

        # y given as a single column name
        single = self.sdf.plot(kind="barh", x="category", y="int_val")
        self._check_fig_data("barh", single["data"][0], categories, [10, 30, 20])

        # y given as a list of column names
        multi_y = self.sdf.plot.barh(x="category", y=["int_val", "float_val"])
        for position, (column, values) in enumerate(per_trace):
            self._check_fig_data("barh", multi_y["data"][position], categories, values, column)

        # x given as a list of column names, category on y
        multi_x = self.sdf.plot.barh(x=["int_val", "float_val"], y="category")
        for position, (column, values) in enumerate(per_trace):
            self._check_fig_data("barh", multi_x["data"][position], values, categories, column)

    def test_scatter_plot(self):
        lengths = [5.1, 4.9, 7.0, 6.4, 5.9]
        widths = [3.5, 3.0, 3.2, 3.2, 3.0]

        # length on x, width on y
        figure = self.sdf2.plot(kind="scatter", x="length", y="width")
        self._check_fig_data("scatter", figure["data"][0], lengths, widths)

        # same columns with the axes swapped
        swapped = self.sdf2.plot.scatter(x="width", y="length")
        self._check_fig_data("scatter", swapped["data"][0], widths, lengths)
class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
    """Concrete test case: the plot checks from the mixin on top of ReusedSQLTestCase."""
if __name__ == "__main__":
    from pyspark.sql.tests.plot.test_frame_plot_plotly import *  # noqa: F401

    # Default to unittest's own runner; switch to the XML-report runner only
    # when xmlrunner can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional, so keep the default runner.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)