#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.errors import PySparkAttributeError, PySparkTypeError
from pyspark.errors.exceptions.base import SessionNotSameException
from pyspark.sql.types import Row
from pyspark.sql import functions as F
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.util import is_remote_only


class SparkConnectErrorTests(ReusedConnectTestCase):
    def test_recursion_handling_for_plan_logging(self):
        """SPARK-45852 - Test that we can handle recursion in plan logging."""
        cdf = self.spark.range(1)
        for x in range(400):
            cdf = cdf.withColumn(f"col_{x}", F.lit(x))

        # Calling schema triggers logging of the message, which in turn triggers the
        # conversion of the message into protobuf, which then triggers the recursion error.
        self.assertIsNotNone(cdf.schema)

        result = self.spark._client._proto_to_string(cdf._plan.to_proto(self.spark._client))
        self.assertIn("recursion", result)

    def test_error_handling(self):
        from pyspark.errors.exceptions.connect import AnalysisException

        # SPARK-41533 Proper error handling for Spark Connect
        df = self.spark.range(10).select("id2")
        with self.assertRaises(AnalysisException):
            df.collect()

    def test_invalid_column(self):
        from pyspark.errors.exceptions.connect import AnalysisException

        # SPARK-41812: fail df1.select(df2.col)
        data1 = [Row(a=1, b=2, c=3)]
        cdf1 = self.spark.createDataFrame(data1)

        data2 = [Row(a=2, b=0)]
        cdf2 = self.spark.createDataFrame(data2)

        with self.assertRaises(AnalysisException):
            cdf1.select(cdf2.a).schema

        with self.assertRaises(AnalysisException):
            cdf2.withColumn("x", cdf1.a + 1).schema

        # Can find the target plan node, but fail to resolve with it
        with self.assertRaisesRegex(
            AnalysisException,
            "UNRESOLVED_COLUMN.WITH_SUGGESTION",
        ):
            cdf3 = cdf1.select(cdf1.a)
            cdf3.select(cdf1.b).schema

        # Can not find the target plan node by plan id
        with self.assertRaisesRegex(
            AnalysisException,
            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
        ):
            cdf1.select(cdf2.a).schema

    def test_invalid_star(self):
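        # df["*"] must resolve against the DataFrame it came from; resolving it against an
        # unrelated or modified plan fails, and a self-join makes the reference ambiguous.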
        from pyspark.errors.exceptions.connect import AnalysisException

        data1 = [Row(a=1, b=2, c=3)]
        cdf1 = self.spark.createDataFrame(data1)

        data2 = [Row(a=2, b=0)]
        cdf2 = self.spark.createDataFrame(data2)

        # Can find the target plan node, but fail to resolve with it
        with self.assertRaisesRegex(
            AnalysisException,
            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
        ):
            cdf3 = cdf1.select(cdf1.a)
            cdf3.select(cdf1["*"]).schema

        # Can find the target plan node, but fail to resolve with it
        with self.assertRaisesRegex(
            AnalysisException,
            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
        ):
            # column 'a' has been replaced
            cdf3 = cdf1.withColumn("a", F.lit(0))
            cdf3.select(cdf1["*"]).schema

        # Can not find the target plan node by plan id
        with self.assertRaisesRegex(
            AnalysisException,
            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
        ):
            cdf1.select(cdf2["*"]).schema

        # cdf1["*"] exists on both sides
        with self.assertRaisesRegex(
            AnalysisException,
            "AMBIGUOUS_COLUMN_REFERENCE",
        ):
            cdf1.join(cdf1).select(cdf1["*"]).schema

    def test_deduplicate_within_watermark_in_batch(self):
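        # dropDuplicatesWithinWatermark is only supported for streaming DataFrames, so
        # calling it on a batch read should fail analysis.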
        from pyspark.errors.exceptions.connect import AnalysisException

        table_name = "tmp_table_for_test_deduplicate_within_watermark_in_batch"
        with self.table(table_name):
            self.spark.createDataFrame(
                [Row(key=i, value=str(i)) for i in range(100)]
            ).write.saveAsTable(table_name)

            with self.assertRaisesRegex(
                AnalysisException,
                "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
            ):
                self.spark.read.table(table_name).dropDuplicatesWithinWatermark().toPandas()

    def test_different_spark_session_join_or_union(self):
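        # Joining or unioning DataFrames created by different Spark sessions is not allowed
        # and should raise SESSION_NOT_SAME.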
        from pyspark.sql.connect.session import SparkSession as RemoteSparkSession

        df = self.spark.range(10).limit(3)

        spark2 = RemoteSparkSession(connection="sc://localhost")
        df2 = spark2.range(10).limit(3)

        with self.assertRaises(SessionNotSameException) as e1:
            df.union(df2).collect()
        self.check_error(
            exception=e1.exception,
            errorClass="SESSION_NOT_SAME",
            messageParameters={},
        )

        with self.assertRaises(SessionNotSameException) as e2:
            df.unionByName(df2).collect()
        self.check_error(
            exception=e2.exception,
            errorClass="SESSION_NOT_SAME",
            messageParameters={},
        )

        with self.assertRaises(SessionNotSameException) as e3:
            df.join(df2).collect()
        self.check_error(
            exception=e3.exception,
            errorClass="SESSION_NOT_SAME",
            messageParameters={},
        )

    @unittest.skipIf(is_remote_only(), "Disabled for remote only")
    def test_unsupported_functions(self):
        # SPARK-41225: Disable unsupported functions.
        df = self.spark.range(10)
        with self.assertRaises(NotImplementedError):
            df.toJSON()
        with self.assertRaises(NotImplementedError):
            df.rdd

    def test_unsupported_jvm_attribute(self):
        # Unsupported jvm attributes for Spark session.
        unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
        spark_session = self.spark
        for attr in unsupported_attrs:
            with self.assertRaises(PySparkAttributeError) as pe:
                getattr(spark_session, attr)

            self.check_error(
                exception=pe.exception,
                errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED",
                messageParameters={"attr_name": attr},
            )

        # Unsupported jvm attributes for DataFrame.
        unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
        cdf = self.spark.range(10)
        for attr in unsupported_attrs:
            with self.assertRaises(PySparkAttributeError) as pe:
                getattr(cdf, attr)

            self.check_error(
                exception=pe.exception,
                errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED",
                messageParameters={"attr_name": attr},
            )

        # Unsupported jvm attributes for Column.
        with self.assertRaises(PySparkAttributeError) as pe:
            getattr(cdf.id, "_jc")
        self.check_error(
            exception=pe.exception,
            errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED",
            messageParameters={"attr_name": "_jc"},
        )

        # Unsupported jvm attributes for DataFrameReader.
        with self.assertRaises(PySparkAttributeError) as pe:
            getattr(spark_session.read, "_jreader")
        self.check_error(
            exception=pe.exception,
            errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED",
            messageParameters={"attr_name": "_jreader"},
        )

    def test_column_cannot_be_constructed_from_string(self):
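        # Constructing a Connect Column directly from a plain string is not supported
        # and should raise a TypeError.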
        from pyspark.sql.connect.column import Column

        with self.assertRaises(TypeError):
            Column("col")

    def test_select_none(self):
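        # select(None) is not a valid column argument and should raise
        # NOT_LIST_OF_COLUMN_OR_STR.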
        with self.assertRaises(PySparkTypeError) as e1:
            self.spark.range(1).select(None)

        self.check_error(
            exception=e1.exception,
            errorClass="NOT_LIST_OF_COLUMN_OR_STR",
            messageParameters={"arg_name": "columns"},
        )

    def test_ym_interval_in_collect(self):
        # YearMonthIntervalType is not supported in the Python-side Arrow conversion.
        with self.assertRaises(PySparkTypeError):
            self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval").first()


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_connect_error import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)