blob: 3dafd71c1a3291c2418737e04ef578da4e67dd06 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from datetime import datetime
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_plotly,
plotly_requirement_message,
have_pandas,
pandas_requirement_message,
)
if have_plotly and have_pandas:
import pyspark.sql.plot # noqa: F401
@unittest.skipIf(
    not have_plotly or not have_pandas, plotly_requirement_message or pandas_requirement_message
)
class DataFramePlotPlotlyTestsMixin:
    """Shared checks for the plotly backend of ``DataFrame.plot``.

    Each test builds a small DataFrame, renders a figure through the plotting
    API, and compares the resulting plotly trace attributes against expected
    values via :meth:`_check_fig_data`. Mixed into a concrete ``TestCase``
    elsewhere (see ``DataFramePlotPlotlyTests``).
    """

    @property
    def sdf(self):
        """3 rows: one string column plus int and float value columns."""
        rows = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        return self.spark.createDataFrame(rows, ["category", "int_val", "float_val"])

    @property
    def sdf2(self):
        """5 rows of measurements, used by scatter and histogram tests."""
        rows = [(5.1, 3.5, "0"), (4.9, 3.0, "0"), (7.0, 3.2, "1"), (6.4, 3.2, "1"), (5.9, 3.0, "2")]
        return self.spark.createDataFrame(rows, ["length", "width", "species"])

    @property
    def sdf3(self):
        """Monthly metrics with a datetime column, used by area and pie tests."""
        rows = [
            (3, 5, 20, datetime(2018, 1, 31)),
            (2, 5, 42, datetime(2018, 2, 28)),
            (3, 6, 28, datetime(2018, 3, 31)),
            (9, 12, 62, datetime(2018, 4, 30)),
        ]
        return self.spark.createDataFrame(rows, ["sales", "signups", "visits", "date"])

    @property
    def sdf4(self):
        """Student scores including deliberate outliers, for box/kde tests."""
        rows = [
            ("A", 50, 55),
            ("B", 55, 60),
            ("C", 60, 65),
            ("D", 65, 70),
            ("E", 70, 75),
            # outliers
            ("F", 10, 15),
            ("G", 85, 90),
            ("H", 5, 150),
        ]
        return self.spark.createDataFrame(rows, ["student", "math_score", "english_score"])

    def _check_fig_data(self, fig_data, **kwargs):
        """Assert that each expected key/value pair matches the plotly trace.

        Sequence-valued fields ("x", "y", "labels", "values") may contain
        numpy scalars, so every element is unboxed via ``.item()`` before the
        comparison; all other fields are compared directly.
        """
        sequence_keys = ("x", "y", "labels", "values")
        for key, want in kwargs.items():
            if key in sequence_keys:
                got = [v.item() if hasattr(v, "item") else v for v in fig_data[key]]
                self.assertEqual(got, want)
            else:
                self.assertEqual(fig_data[key], want)

    def test_line_plot(self):
        base = {
            "mode": "lines",
            "orientation": "v",
            "x": ["A", "B", "C"],
            "xaxis": "x",
            "yaxis": "y",
            "type": "scatter",
        }
        # single column as vertical axis
        fig = self.sdf.plot(kind="line", x="category", y="int_val")
        self._check_fig_data(fig["data"][0], **{**base, "name": "", "y": [10, 30, 20]})
        # multiple columns as vertical axis
        fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
        self._check_fig_data(fig["data"][0], **{**base, "name": "int_val", "y": [10, 30, 20]})
        self._check_fig_data(fig["data"][1], **{**base, "name": "float_val", "y": [1.5, 2.5, 3.5]})

    def test_bar_plot(self):
        base = {
            "orientation": "v",
            "x": ["A", "B", "C"],
            "xaxis": "x",
            "yaxis": "y",
            "type": "bar",
        }
        # single column as vertical axis
        fig = self.sdf.plot(kind="bar", x="category", y="int_val")
        self._check_fig_data(fig["data"][0], **{**base, "name": "", "y": [10, 30, 20]})
        # multiple columns as vertical axis
        fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"])
        self._check_fig_data(fig["data"][0], **{**base, "name": "int_val", "y": [10, 30, 20]})
        self._check_fig_data(fig["data"][1], **{**base, "name": "float_val", "y": [1.5, 2.5, 3.5]})

    def test_barh_plot(self):
        base = {
            "orientation": "h",
            "xaxis": "x",
            "yaxis": "y",
            "type": "bar",
        }
        # single column as vertical axis
        fig = self.sdf.plot(kind="barh", x="category", y="int_val")
        self._check_fig_data(
            fig["data"][0], **{**base, "name": "", "x": ["A", "B", "C"], "y": [10, 30, 20]}
        )
        # multiple columns as vertical axis
        fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"])
        self._check_fig_data(
            fig["data"][0], **{**base, "name": "int_val", "x": ["A", "B", "C"], "y": [10, 30, 20]}
        )
        self._check_fig_data(
            fig["data"][1],
            **{**base, "name": "float_val", "x": ["A", "B", "C"], "y": [1.5, 2.5, 3.5]},
        )
        # multiple columns as horizontal axis
        fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category")
        self._check_fig_data(
            fig["data"][0], **{**base, "name": "int_val", "x": [10, 30, 20], "y": ["A", "B", "C"]}
        )
        self._check_fig_data(
            fig["data"][1],
            **{**base, "name": "float_val", "x": [1.5, 2.5, 3.5], "y": ["A", "B", "C"]},
        )

    def test_scatter_plot(self):
        base = {"name": "", "orientation": "v", "xaxis": "x", "yaxis": "y", "type": "scatter"}
        lengths = [5.1, 4.9, 7.0, 6.4, 5.9]
        widths = [3.5, 3.0, 3.2, 3.2, 3.0]
        fig = self.sdf2.plot(kind="scatter", x="length", y="width")
        self._check_fig_data(fig["data"][0], **{**base, "x": lengths, "y": widths})
        # swapping the axes swaps the trace's x/y arrays
        fig = self.sdf2.plot.scatter(x="width", y="length")
        self._check_fig_data(fig["data"][0], **{**base, "x": widths, "y": lengths})

    def test_area_plot(self):
        dates = [
            datetime(2018, 1, 31, 0, 0),
            datetime(2018, 2, 28, 0, 0),
            datetime(2018, 3, 31, 0, 0),
            datetime(2018, 4, 30, 0, 0),
        ]
        base = {
            "orientation": "v",
            "x": dates,
            "xaxis": "x",
            "yaxis": "y",
            "mode": "lines",
            "type": "scatter",
        }
        # single column as vertical axis
        fig = self.sdf3.plot(kind="area", x="date", y="sales")
        self._check_fig_data(fig["data"][0], **{**base, "name": "", "y": [3, 2, 3, 9]})
        # multiple columns as vertical axis
        fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"])
        self._check_fig_data(fig["data"][0], **{**base, "name": "sales", "y": [3, 2, 3, 9]})
        self._check_fig_data(fig["data"][1], **{**base, "name": "signups", "y": [5, 5, 6, 12]})
        self._check_fig_data(fig["data"][2], **{**base, "name": "visits", "y": [20, 42, 28, 62]})

    def test_pie_plot(self):
        dates = [
            datetime(2018, 1, 31, 0, 0),
            datetime(2018, 2, 28, 0, 0),
            datetime(2018, 3, 31, 0, 0),
            datetime(2018, 4, 30, 0, 0),
        ]
        expected_sales = {"name": "", "labels": dates, "values": [3, 2, 3, 9], "type": "pie"}
        expected_signups = {**expected_sales, "values": [5, 5, 6, 12]}
        expected_visits = {**expected_sales, "values": [20, 42, 28, 62]}
        # single column as 'y'
        fig = self.sdf3.plot(kind="pie", x="date", y="sales")
        self._check_fig_data(fig["data"][0], **expected_sales)
        # all numeric columns as 'y'
        fig = self.sdf3.plot(kind="pie", x="date", subplots=True)
        self._check_fig_data(fig["data"][0], **expected_sales)
        self._check_fig_data(fig["data"][1], **expected_signups)
        self._check_fig_data(fig["data"][2], **expected_visits)
        # not specify subplots
        with self.assertRaises(PySparkValueError) as pe:
            self.sdf3.plot(kind="pie", x="date")
        self.check_error(
            exception=pe.exception, errorClass="UNSUPPORTED_PIE_PLOT_PARAM", messageParameters={}
        )
        # y is not a numerical column
        with self.assertRaises(PySparkTypeError) as pe:
            self.sdf.plot.pie(x="int_val", y="category")
        self.check_error(
            exception=pe.exception,
            errorClass="PLOT_INVALID_TYPE_COLUMN",
            messageParameters={
                "col_name": "category",
                "valid_types": "NumericType",
                "col_type": "StringType",
            },
        )

    def test_box_plot(self):
        expected_math = {
            "boxpoints": "suspectedoutliers",
            "lowerfence": (5,),
            "mean": (50.0,),
            "median": (55,),
            "name": "math_score",
            "notched": False,
            "q1": (10,),
            "q3": (65,),
            "upperfence": (85,),
            "x": [0],
            "type": "box",
        }
        expected_english = {
            "boxpoints": "suspectedoutliers",
            "lowerfence": (55,),
            "mean": (72.5,),
            "median": (65,),
            "name": "english_score",
            "notched": False,
            "q1": (55,),
            "q3": (75,),
            "upperfence": (90,),
            "x": [1],
            "y": [[150, 15]],
            "type": "box",
        }
        # explicit single column
        fig = self.sdf4.plot.box(column="math_score")
        self._check_fig_data(fig["data"][0], **expected_math)
        # explicit list of columns
        fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"])
        self._check_fig_data(fig["data"][0], **expected_math)
        self._check_fig_data(fig["data"][1], **expected_english)
        # default: all numeric columns
        fig = self.sdf4.plot(kind="box")
        self._check_fig_data(fig["data"][0], **expected_math)
        self._check_fig_data(fig["data"][1], **expected_english)
        # unsupported backend parameter values are rejected
        with self.assertRaises(PySparkValueError) as pe:
            self.sdf4.plot.box(column="math_score", boxpoints=True)
        self.check_error(
            exception=pe.exception,
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "boxpoints",
                "value": "True",
                "supported_values": ", ".join(["suspectedoutliers", "False"]),
            },
        )
        with self.assertRaises(PySparkValueError) as pe:
            self.sdf4.plot.box(column="math_score", notched=True)
        self.check_error(
            exception=pe.exception,
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "notched",
                "value": "True",
                "supported_values": ", ".join(["False"]),
            },
        )

    def test_kde_plot(self):
        base = {
            "mode": "lines",
            "orientation": "v",
            "xaxis": "x",
            "yaxis": "y",
            "type": "scatter",
        }
        expected_math = {**base, "name": "math_score"}
        expected_english = {**base, "name": "english_score"}
        fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
        self._check_fig_data(fig["data"][0], **expected_math)
        fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5)
        self._check_fig_data(fig["data"][0], **expected_math)
        self._check_fig_data(fig["data"][1], **expected_english)
        # both traces are evaluated over the same x grid
        self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"]))
        # default: all numeric columns
        fig = self.sdf4.plot.kde(bw_method=0.3, ind=5)
        self._check_fig_data(fig["data"][0], **expected_math)
        self._check_fig_data(fig["data"][1], **expected_english)
        self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"]))

    def test_hist_plot(self):
        # single column: bin centers, counts, and interval labels
        fig = self.sdf2.plot.hist(column="length", bins=4)
        self._check_fig_data(
            fig["data"][0],
            name="length",
            x=[5.1625000000000005, 5.6875, 6.2125, 6.7375],
            y=[2, 1, 1, 1],
            text=("[4.9, 5.425)", "[5.425, 5.95)", "[5.95, 6.475)", "[6.475, 7.0]"),
            type="bar",
        )
        # multiple columns share a common binning over the combined range
        expected_length = {
            "name": "length",
            "x": [3.5, 4.5, 5.5, 6.5],
            "y": [0, 1, 2, 2],
            "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"),
            "type": "bar",
        }
        expected_width = {
            "name": "width",
            "x": [3.5, 4.5, 5.5, 6.5],
            "y": [5, 0, 0, 0],
            "text": ("[3.0, 4.0)", "[4.0, 5.0)", "[5.0, 6.0)", "[6.0, 7.0]"),
            "type": "bar",
        }
        fig = self.sdf2.plot.hist(column=["length", "width"], bins=4)
        self._check_fig_data(fig["data"][0], **expected_length)
        self._check_fig_data(fig["data"][1], **expected_width)
        # default: all numeric columns
        fig = self.sdf2.plot.hist(bins=4)
        self._check_fig_data(fig["data"][0], **expected_length)
        self._check_fig_data(fig["data"][1], **expected_width)

    def test_process_column_param_errors(self):
        # nonexistent column
        with self.assertRaises(PySparkTypeError) as pe:
            self.sdf4.plot.box(column="math_scor")
        self.check_error(
            exception=pe.exception,
            errorClass="PLOT_INVALID_TYPE_COLUMN",
            messageParameters={
                "col_name": "math_scor",
                "valid_types": "NumericType",
                "col_type": "None",
            },
        )
        # non-numeric column
        with self.assertRaises(PySparkTypeError) as pe:
            self.sdf4.plot.box(column="student")
        self.check_error(
            exception=pe.exception,
            errorClass="PLOT_INVALID_TYPE_COLUMN",
            messageParameters={
                "col_name": "student",
                "valid_types": "NumericType",
                "col_type": "StringType",
            },
        )
class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
    """Concrete TestCase: runs the plotly plotting checks against a shared SparkSession."""
if __name__ == "__main__":
    from pyspark.sql.tests.plot.test_frame_plot_plotly import *  # noqa: F401

    # Emit XML test reports when xmlrunner is installed; otherwise fall back
    # to the default text runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)