# -*- encoding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
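"""Tests for the PySpark testing utilities assertDataFrameEqual and assertSchemaEqual,
and for user-facing exception capture in Spark SQL."""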
import unittest
import difflib
from itertools import zip_longest
from pyspark.errors import (
    AnalysisException,
    ParseException,
    PySparkAssertionError,
    PySparkValueError,
    IllegalArgumentException,
    SparkUpgradeException,
    QueryContextType,
)
from pyspark.testing.utils import (
    assertDataFrameEqual,
    assertSchemaEqual,
    _context_diff,
    have_numpy,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow
from pyspark.sql import Row
import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
from pyspark.sql.types import (
StringType,
ArrayType,
LongType,
StructType,
MapType,
FloatType,
DoubleType,
StructField,
IntegerType,
BooleanType,
)
class UtilsTestsMixin:
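    """Tests for assertDataFrameEqual, assertSchemaEqual, and exception capture,
    mixed into concrete ``ReusedSQLTestCase`` subclasses below."""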
def test_assert_equal_inttype(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 3000),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 3000),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_arraytype(self):
df1 = self.spark.createDataFrame(
data=[
("john", ["Python", "Java"]),
("jane", ["Scala", "SQL", "Java"]),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField("languages", ArrayType(StringType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("john", ["Python", "Java"]),
("jane", ["Scala", "SQL", "Java"]),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField("languages", ArrayType(StringType()), True),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_approx_equal_arraytype_float(self):
df1 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.34]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.339999]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_approx_equal_arraytype_float_default_rtol_fail(self):
# fails with default rtol, 1e-5
df1 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.34]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.341]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
)
error_msg = "Results do not match: "
percent_diff = (1 / 2) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_assert_approx_equal_arraytype_float_custom_rtol_pass(self):
# passes with custom rtol, 1e-2
df1 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.34]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", [97.01, 89.23]),
("student2", [91.86, 84.341]),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", ArrayType(FloatType()), True),
]
),
)
assertDataFrameEqual(df1, df2, rtol=1e-2)
def test_assert_approx_equal_doubletype_custom_rtol_pass(self):
# passes with custom rtol, 1e-2
df1 = self.spark.createDataFrame(
data=[
("student1", 97.01),
("student2", 84.34),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grade", DoubleType(), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", 97.01),
("student2", 84.341),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grade", DoubleType(), True),
]
),
)
assertDataFrameEqual(df1, df2, rtol=1e-2)
def test_assert_approx_equal_decimaltype_custom_rtol_pass(self):
        # passes with custom rtol, 1e-1
df1 = self.spark.createDataFrame(
data=[
("student1", 83.14),
("student2", 97.12),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grade", DoubleType(), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", 83.14),
("student2", 97.111),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grade", DoubleType(), True),
]
),
)
# cast to DecimalType
df1 = df1.withColumn("col_1", F.col("grade").cast("decimal(5,3)"))
df2 = df2.withColumn("col_1", F.col("grade").cast("decimal(5,3)"))
assertDataFrameEqual(df1, df2, rtol=1e-1)
def test_assert_notequal_arraytype(self):
df1 = self.spark.createDataFrame(
data=[
("Amy", ["C++", "Rust"]),
("John", ["Python", "Java"]),
("Jane", ["Scala", "SQL", "Java"]),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField("languages", ArrayType(StringType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("Amy", ["C++", "Rust"]),
("John", ["Python", "Java"]),
("Jane", ["Scala", "Java"]),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField("languages", ArrayType(StringType()), True),
]
),
)
rows_str1 = ""
rows_str2 = ""
sorted_list1 = sorted(df1.collect(), key=lambda x: str(x))
sorted_list2 = sorted(df2.collect(), key=lambda x: str(x))
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(sorted_list1, sorted_list2)):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
)
error_msg = "Results do not match: "
percent_diff = (1 / 3) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
)
error_msg = "Results do not match: "
percent_diff = (1 / 3) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_assert_equal_maptype(self):
df1 = self.spark.createDataFrame(
data=[
("student1", {"id": 222342203655477580}),
("student2", {"id": 422322203155477692}),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("properties", MapType(StringType(), LongType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", {"id": 222342203655477580}),
("student2", {"id": 422322203155477692}),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("properties", MapType(StringType(), LongType()), True),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_approx_equal_maptype_double(self):
df1 = self.spark.createDataFrame(
data=[
("student1", {"math": 76.23, "english": 92.64}),
("student2", {"math": 87.89, "english": 84.48}),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", MapType(StringType(), DoubleType()), True),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("student1", {"math": 76.23, "english": 92.63999999}),
("student2", {"math": 87.89, "english": 84.48}),
],
schema=StructType(
[
StructField("student", StringType(), True),
StructField("grades", MapType(StringType(), DoubleType()), True),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_approx_equal_nested_struct_double(self):
df1 = self.spark.createDataFrame(
data=[
("jane", (64.57, 76.63, 97.81)),
("john", (93.92, 91.57, 84.36)),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField(
"grades",
StructType(
[
StructField("math", DoubleType(), True),
StructField("english", DoubleType(), True),
StructField("biology", DoubleType(), True),
]
),
),
]
),
)
df2 = self.spark.createDataFrame(
data=[
("jane", (64.57, 76.63, 97.81000001)),
("john", (93.92, 91.57, 84.36)),
],
schema=StructType(
[
StructField("name", StringType(), True),
StructField(
"grades",
StructType(
[
StructField("math", DoubleType(), True),
StructField("english", DoubleType(), True),
StructField("biology", DoubleType(), True),
]
),
),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_nested_struct_str(self):
df1 = self.spark.createDataFrame(
data=[
(1, ("jane", "anne", "doe")),
(2, ("john", "bob", "smith")),
],
schema=StructType(
[
StructField("id", IntegerType(), True),
StructField(
"name",
StructType(
[
StructField("first", StringType(), True),
StructField("middle", StringType(), True),
StructField("last", StringType(), True),
]
),
),
]
),
)
df2 = self.spark.createDataFrame(
data=[
(1, ("jane", "anne", "doe")),
(2, ("john", "bob", "smith")),
],
schema=StructType(
[
StructField("id", IntegerType(), True),
StructField(
"name",
StructType(
[
StructField("first", StringType(), True),
StructField("middle", StringType(), True),
StructField("last", StringType(), True),
]
),
),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_nested_struct_str_duplicate(self):
df1 = self.spark.createDataFrame(
data=[
(1, ("jane doe", "jane doe")),
(2, ("john smith", "john smith")),
],
schema=StructType(
[
StructField("id", IntegerType(), True),
StructField(
"full name",
StructType(
[
StructField("name", StringType(), True),
StructField("name", StringType(), True),
]
),
),
]
),
)
df2 = self.spark.createDataFrame(
data=[
(1, ("jane doe", "jane doe")),
(2, ("john smith", "john smith")),
],
schema=StructType(
[
StructField("id", IntegerType(), True),
StructField(
"full name",
StructType(
[
StructField("name", StringType(), True),
StructField("name", StringType(), True),
]
),
),
]
),
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_duplicate_col(self):
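        # DataFrames with duplicate column names should still compare as equal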
df1 = self.spark.createDataFrame(
data=[
(1, "Python", 1, 1),
(2, "Scala", 2, 2),
],
schema=["number", "language", "number", "number"],
)
df2 = self.spark.createDataFrame(
data=[
(1, "Python", 1, 1),
(2, "Scala", 2, 2),
],
schema=["number", "language", "number", "number"],
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_timestamp(self):
df1 = self.spark.createDataFrame(
data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
)
df2 = self.spark.createDataFrame(
data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
)
df1 = df1.withColumn("timestamp", F.to_timestamp("timestamp"))
df2 = df2.withColumn("timestamp", F.to_timestamp("timestamp"))
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_equal_nullrow(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000),
(None, None),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000),
(None, None),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_notequal_nullval(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 2000),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", None),
],
schema=["id", "amount"],
)
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
)
error_msg = "Results do not match: "
percent_diff = (1 / 2) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_assert_equal_nulldf(self):
df1 = None
df2 = None
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_unequal_null_actual(self):
df1 = None
df2 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 3000),
],
schema=["id", "amount"],
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "actual",
"actual_type": None,
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "actual",
"actual_type": None,
},
)
def test_assert_unequal_null_expected(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 3000),
],
schema=["id", "amount"],
)
df2 = None
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "expected",
"actual_type": None,
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "expected",
"actual_type": None,
},
)
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_equal_exact_pandas_df(self):
import pandas as pd
import numpy as np
df1 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"]
)
df2 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"]
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_approx_equal_pandas_df(self):
import pandas as pd
import numpy as np
# test that asserts close enough equality for pandas df
df1 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59)]), columns=["a", "b", "c"]
)
df2 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59.0001)]), columns=["a", "b", "c"]
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_approx_equal_fail_exact_pandas_df(self):
import pandas as pd
import numpy as np
        # test that exact comparison (rtol=0, atol=0) fails for approximately equal values
df1 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59)]), columns=["a", "b", "c"]
)
df2 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59.0001)]), columns=["a", "b", "c"]
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=False, rtol=0, atol=0)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True, rtol=0, atol=0)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_unequal_pandas_df(self):
import pandas as pd
import numpy as np
df1 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (6, 5, 4)]), columns=["a", "b", "c"]
)
df2 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"]
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=False)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
@unittest.skipIf(
not have_pandas or not have_numpy or not have_pyarrow,
"no pandas or numpy or pyarrow dependency",
)
def test_assert_type_error_pandas_df(self):
import pyspark.pandas as ps
import pandas as pd
import numpy as np
df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
df2 = pd.DataFrame(
data=np.array([(1, 2, 3), (4, 5, 6), (6, 5, 4)]), columns=["a", "b", "c"]
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=False)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_PANDAS_DATAFRAME",
message_parameters={
"left": df1.to_string(),
"left_dtype": str(df1.dtypes),
"right": df2.to_string(),
"right_dtype": str(df2.dtypes),
},
)
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_equal_exact_pandas_on_spark_df(self):
import pyspark.pandas as ps
df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
    def test_assert_equal_unordered_pandas_on_spark_df(self):
        # row order is ignored by default, so reversed data still compares equal
import pyspark.pandas as ps
df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"])
assertDataFrameEqual(df1, df2)
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_equal_approx_pandas_on_spark_df(self):
import pyspark.pandas as ps
df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"])
df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"])
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
@unittest.skipIf(not have_pandas or not have_pyarrow, "no pandas or pyarrow dependency")
def test_assert_error_pandas_pyspark_df(self):
import pyspark.pandas as ps
import pandas as pd
df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
df2 = self.spark.createDataFrame([(10,), (11,), (13,)], ["Numbers"])
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=False)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": f"{ps.DataFrame.__name__}, "
f"{pd.DataFrame.__name__}, "
f"{ps.Series.__name__}, "
f"{pd.Series.__name__}, "
f"{ps.Index.__name__}"
f"{pd.Index.__name__}, ",
"arg_name": "expected",
"actual_type": type(df2),
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": f"{ps.DataFrame.__name__}, "
f"{pd.DataFrame.__name__}, "
f"{ps.Series.__name__}, "
f"{pd.Series.__name__}, "
f"{ps.Index.__name__}"
f"{pd.Index.__name__}, ",
"arg_name": "expected",
"actual_type": type(df2),
},
)
def test_assert_error_non_pyspark_df(self):
dict1 = {"a": 1, "b": 2}
dict2 = {"a": 1, "b": 2}
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(dict1, dict2)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "actual",
"actual_type": type(dict1),
},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(dict1, dict2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
"expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
"arg_name": "actual",
"actual_type": type(dict1),
},
)
def test_row_order_ignored(self):
# test that row order is ignored (not checked) by default
df1 = self.spark.createDataFrame(
data=[
("2", 3000.00),
("1", 1000.00),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000.00),
("2", 3000.00),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2)
def test_check_row_order_error(self):
# test checkRowOrder=True
df1 = self.spark.createDataFrame(
data=[
("2", 3000.00),
("1", 1000.00),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000.00),
("2", 3000.00),
],
schema=["id", "amount"],
)
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
)
error_msg = "Results do not match: "
percent_diff = (2 / 2) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_remove_non_word_characters_long(self):
def remove_non_word_characters(col):
return F.regexp_replace(col, "[^\\w\\s]+", "")
source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)]
source_df = self.spark.createDataFrame(source_data, ["name"])
actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name")))
expected_data = [("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", "luisa"), (None, None)]
expected_df = self.spark.createDataFrame(expected_data, ["name", "clean_name"])
assertDataFrameEqual(actual_df, expected_df)
def test_assert_pyspark_approx_equal(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000.00),
("2", 3000.00),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000.0000001),
("2", 3000.00),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_assert_pyspark_approx_equal_custom_rtol(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000.00),
("2", 3000.00),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000.01),
("2", 3000.00),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2, rtol=1e-2)
def test_assert_pyspark_df_not_equal(self):
df1 = self.spark.createDataFrame(
data=[
("1", 1000.00),
("2", 3000.00),
("3", 2000.00),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1001.00),
("2", 3000.00),
("3", 2003.00),
],
schema=["id", "amount"],
)
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
)
error_msg = "Results do not match: "
percent_diff = (2 / 3) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_assert_notequal_schema(self):
df1 = self.spark.createDataFrame(
data=[
(1, 1000),
(2, 3000),
],
schema=["id", "number"],
)
df2 = self.spark.createDataFrame(
data=[
("1", 1000),
("2", 5000),
],
schema=["id", "amount"],
)
generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
expected_error_msg = "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_SCHEMA",
message_parameters={"error_msg": expected_error_msg},
)
def test_diff_schema_lens(self):
df1 = self.spark.createDataFrame(
data=[
(1, 3000),
(2, 1000),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
(1, 3000, "a"),
(2, 1000, "b"),
],
schema=["id", "amount", "letter"],
)
generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
expected_error_msg = "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, df2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_SCHEMA",
message_parameters={"error_msg": expected_error_msg},
)
def test_schema_ignore_nullable(self):
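        # nullability differences are ignored unless ignoreNullable=False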
s1 = StructType(
[StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
)
df1 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s1)
s2 = StructType(
[StructField("id", IntegerType(), True), StructField("name", StringType(), False)]
)
df2 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s2)
assertDataFrameEqual(df1, df2)
with self.assertRaises(PySparkAssertionError):
assertDataFrameEqual(df1, df2, ignoreNullable=False)
def test_schema_ignore_nullable_array_equal(self):
s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
assertSchemaEqual(s1, s2)
def test_schema_ignore_nullable_struct_equal(self):
s1 = StructType(
[StructField("names", StructType([StructField("age", IntegerType(), True)]), True)]
)
s2 = StructType(
[StructField("names", StructType([StructField("age", IntegerType(), False)]), False)]
)
assertSchemaEqual(s1, s2)
def test_schema_array_unequal(self):
s1 = StructType([StructField("names", ArrayType(IntegerType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
expected_error_msg = "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertSchemaEqual(s1, s2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_SCHEMA",
message_parameters={"error_msg": expected_error_msg},
)
def test_schema_struct_unequal(self):
s1 = StructType(
[StructField("names", StructType([StructField("age", DoubleType(), True)]), True)]
)
s2 = StructType(
[StructField("names", StructType([StructField("age", IntegerType(), True)]), True)]
)
generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
expected_error_msg = "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertSchemaEqual(s1, s2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_SCHEMA",
message_parameters={"error_msg": expected_error_msg},
)
def test_schema_more_nested_struct_unequal(self):
s1 = StructType(
[
StructField(
"name",
StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
),
),
]
)
s2 = StructType(
[
StructField(
"name",
StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", BooleanType(), True),
StructField("lastname", StringType(), True),
]
),
),
]
)
generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
expected_error_msg = "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertSchemaEqual(s1, s2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_SCHEMA",
message_parameters={"error_msg": expected_error_msg},
)
def test_schema_unsupported_type(self):
s1 = "names: int"
s2 = "names: int"
with self.assertRaises(PySparkAssertionError) as pe:
assertSchemaEqual(s1, s2)
self.check_error(
exception=pe.exception,
error_class="UNSUPPORTED_DATA_TYPE",
message_parameters={"data_type": type(s1)},
)
def test_spark_sql(self):
assertDataFrameEqual(self.spark.sql("select 1 + 2 AS x"), self.spark.sql("select 3 AS x"))
assertDataFrameEqual(
self.spark.sql("select 1 + 2 AS x"),
self.spark.sql("select 3 AS x"),
checkRowOrder=True,
)
def test_spark_sql_sort_rows(self):
df1 = self.spark.createDataFrame(
data=[
(1, 3000),
(2, 1000),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
(2, 1000),
(1, 3000),
],
schema=["id", "amount"],
)
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
assertDataFrameEqual(
self.spark.sql("select * from df1 order by amount"), self.spark.sql("select * from df2")
)
assertDataFrameEqual(
self.spark.sql("select * from df1 order by amount"),
self.spark.sql("select * from df2"),
checkRowOrder=True,
)
def test_empty_dataset(self):
df1 = self.spark.range(0, 10).limit(0)
df2 = self.spark.range(0, 10).limit(0)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_no_column(self):
df1 = self.spark.range(0, 10).drop("id")
df2 = self.spark.range(0, 10).drop("id")
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_empty_no_column(self):
df1 = self.spark.range(0, 10).drop("id").limit(0)
df2 = self.spark.range(0, 10).drop("id").limit(0)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_empty_expected_list(self):
df1 = self.spark.range(0, 5).drop("id")
df2 = [Row(), Row(), Row(), Row(), Row()]
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_no_column_expected_list(self):
df1 = self.spark.range(0, 10).limit(0)
df2 = []
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_empty_no_column_expected_list(self):
df1 = self.spark.range(0, 10).drop("id").limit(0)
df2 = []
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_special_vals(self):
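        # NaN and +/-infinity values should compare equal to themselves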
df1 = self.spark.createDataFrame(
data=[
(1, float("nan")),
(2, float("inf")),
(2, float("-inf")),
],
schema=["id", "amount"],
)
df2 = self.spark.createDataFrame(
data=[
(1, float("nan")),
(2, float("inf")),
(2, float("-inf")),
],
schema=["id", "amount"],
)
assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
def test_df_list_row_equal(self):
df1 = self.spark.createDataFrame(
data=[
(1, 3000),
(2, 1000),
],
schema=["id", "amount"],
)
list_of_rows = [Row(1, 3000), Row(2, 1000)]
assertDataFrameEqual(df1, list_of_rows, checkRowOrder=False)
assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
def test_list_rows_equal(self):
list_of_rows1 = [Row(1, "abc", 5000), Row(2, "def", 1000)]
list_of_rows2 = [Row(1, "abc", 5000), Row(2, "def", 1000)]
assertDataFrameEqual(list_of_rows1, list_of_rows2, checkRowOrder=False)
assertDataFrameEqual(list_of_rows1, list_of_rows2, checkRowOrder=True)
def test_list_rows_unequal(self):
list_of_rows1 = [Row(1, "abc", 5000), Row(2, "def", 1000)]
list_of_rows2 = [Row(1, "abc", 5000), Row(2, "defg", 1000)]
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(list_of_rows1, list_of_rows2)):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
)
error_msg = "Results do not match: "
percent_diff = (1 / 2) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(list_of_rows1, list_of_rows2)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(list_of_rows1, list_of_rows2, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
    def test_df_list_row_unequal(self):
df1 = self.spark.createDataFrame(
data=[
(1, 3000),
(2, 1000),
(3, 10),
],
schema=["id", "amount"],
)
list_of_rows = [Row(id=1, amount=300), Row(id=2, amount=100), Row(id=3, amount=10)]
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
        for r1, r2 in list(zip_longest(df1.collect(), list_of_rows)):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
)
error_msg = "Results do not match: "
percent_diff = (2 / 3) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, list_of_rows)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_list_row_unequal_schema(self):
df1 = self.spark.createDataFrame(
data=[
(1, 3000),
(2, 1000),
],
schema=["id", "amount"],
)
list_of_rows = [Row(1, "3000"), Row(2, "1000")]
rows_str1 = ""
rows_str2 = ""
        # render the rows as strings to build the expected diff
for r1, r2 in list(zip_longest(df1.collect(), list_of_rows)):
rows_str1 += str(r1) + "\n"
rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
)
error_msg = "Results do not match: "
percent_diff = (2 / 2) * 100
error_msg += "( %.5f %% )" % percent_diff
error_msg += "\n" + "\n".join(generated_diff)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, list_of_rows)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
with self.assertRaises(PySparkAssertionError) as pe:
assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
self.check_error(
exception=pe.exception,
error_class="DIFFERENT_ROWS",
message_parameters={"error_msg": error_msg},
)
def test_dataframe_include_diff_rows(self):
df1 = self.spark.createDataFrame(
[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], ["id", "amount"]
)
df2 = self.spark.createDataFrame(
[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], ["id", "amount"]
)
with self.assertRaises(PySparkAssertionError) as context:
assertDataFrameEqual(df1, df2, includeDiffRows=True)
# Extracting the differing rows data from the exception
error_data = context.exception.data
# Expected differences
expected_diff = [
(Row(id="1", amount=1000.0), Row(id="1", amount=1001.0)),
(Row(id="3", amount=2000.0), Row(id="3", amount=2003.0)),
]
self.assertEqual(error_data, expected_diff)
def test_dataframe_ignore_column_order(self):
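        # ignoreColumnOrder=True matches columns by name rather than by position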
df1 = self.spark.createDataFrame([Row(A=1, B=2), Row(A=3, B=4)])
df2 = self.spark.createDataFrame([Row(B=2, A=1), Row(B=4, A=3)])
with self.assertRaises(PySparkAssertionError):
assertDataFrameEqual(df1, df2, ignoreColumnOrder=False)
assertDataFrameEqual(df1, df2, ignoreColumnOrder=True)
def test_dataframe_ignore_column_name(self):
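        # ignoreColumnName=True compares columns positionally, ignoring their names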
df1 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["X", "Y"])
with self.assertRaises(PySparkAssertionError):
assertDataFrameEqual(df1, df2, ignoreColumnName=False)
assertDataFrameEqual(df1, df2, ignoreColumnName=True)
def test_dataframe_ignore_column_type(self):
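        # ignoreColumnType=True compares values across differing column data types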
df1 = self.spark.createDataFrame([(1, "2"), (3, "4")], ["A", "B"])
df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
with self.assertRaises(PySparkAssertionError):
assertDataFrameEqual(df1, df2, ignoreColumnType=False)
assertDataFrameEqual(df1, df2, ignoreColumnType=True)
def test_dataframe_max_errors(self):
df1 = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "c"), (4, "d")], ["id", "value"])
df2 = self.spark.createDataFrame([(1, "a"), (2, "z"), (3, "x"), (4, "y")], ["id", "value"])
# We expect differences in rows 2, 3, and 4.
# Setting maxErrors to 2 will limit the reported errors.
maxErrors = 2
with self.assertRaises(PySparkAssertionError) as context:
assertDataFrameEqual(df1, df2, maxErrors=maxErrors)
# Check if the error message contains information about 2 mismatches only.
error_message = str(context.exception)
self.assertTrue("! Row" in error_message and error_message.count("! Row") == maxErrors * 2)
def test_dataframe_show_only_diff(self):
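        # showOnlyDiff=True restricts the reported diff to mismatched rows only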
df1 = self.spark.createDataFrame(
[(1, "apple", "red"), (2, "banana", "yellow"), (3, "cherry", "red")],
["id", "fruit", "color"],
)
df2 = self.spark.createDataFrame(
[(1, "apple", "green"), (2, "banana", "yellow"), (3, "cherry", "blue")],
["id", "fruit", "color"],
)
with self.assertRaises(PySparkAssertionError) as context:
assertDataFrameEqual(df1, df2, showOnlyDiff=False)
error_message = str(context.exception)
self.assertTrue("apple" in error_message and "banana" in error_message)
with self.assertRaises(PySparkAssertionError) as context:
assertDataFrameEqual(df1, df2, showOnlyDiff=True)
error_message = str(context.exception)
self.assertTrue("apple" in error_message and "banana" not in error_message)
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b").collect())
def test_capture_user_friendly_exception(self):
try:
self.spark.sql("select `中文字段`")
except AnalysisException as e:
self.assertRegex(str(e), ".*UNRESOLVED_COLUMN.*`中文字段`.*")
def test_spark_upgrade_exception(self):
# SPARK-32161 : Test case to Handle SparkUpgradeException in pythonic way
df = self.spark.createDataFrame([("2014-31-12",)], ["date_str"])
df2 = df.select(
"date_str", to_date(from_unixtime(unix_timestamp("date_str", "yyyy-dd-aa")))
)
self.assertRaises(SparkUpgradeException, df2.collect)
def test_capture_parse_exception(self):
self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegex(
IllegalArgumentException,
"Setting negative mapred.reduce.tasks",
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"),
)
def test_capture_pyspark_value_exception(self):
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
self.assertRaisesRegex(
PySparkValueError,
"Value for `numBits` has to be amongst the following values",
lambda: df.select(F.sha2(df.a, 1024)).collect(),
)
def test_get_error_class_state(self):
# SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
exception = None
try:
self.spark.sql("""SELECT a""")
except AnalysisException as e:
exception = e
self.assertIsNotNone(exception)
self.assertEqual(exception.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
self.assertEqual(exception.getSqlState(), "42703")
self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"})
self.assertIn(
(
"[UNRESOLVED_COLUMN.WITHOUT_SUGGESTION] A column, variable, or function "
"parameter with name `a` cannot be resolved. SQLSTATE: 42703"
),
exception.getMessage(),
)
self.assertEqual(len(exception.getQueryContext()), 1)
qc = exception.getQueryContext()[0]
self.assertEqual(qc.fragment(), "a")
self.assertEqual(qc.stopIndex(), 7)
self.assertEqual(qc.startIndex(), 7)
self.assertEqual(qc.contextType(), QueryContextType.SQL)
self.assertEqual(qc.objectName(), "")
self.assertEqual(qc.objectType(), "")
try:
self.spark.sql("""SELECT assert_true(FALSE)""")
except AnalysisException as e:
self.assertIsNone(e.getErrorClass())
self.assertIsNone(e.getSqlState())
self.assertEqual(e.getMessageParameters(), {})
self.assertEqual(e.getMessage(), "")
def test_assert_data_frame_equal_not_support_streaming(self):
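        # streaming DataFrames cannot be collected, so equality checks are unsupported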
df1 = self.spark.readStream.format("rate").load()
df2 = self.spark.readStream.format("rate").load()
exception_thrown = False
try:
assertDataFrameEqual(df1, df2)
except PySparkAssertionError as e:
self.assertEqual(e.getErrorClass(), "UNSUPPORTED_OPERATION")
exception_thrown = True
self.assertTrue(exception_thrown)
class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
pass
if __name__ == "__main__":
import unittest
from pyspark.sql.tests.test_utils import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)