#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.errors import (
AnalysisException,
ArithmeticException,
QueryContextType,
NumberFormatException,
)
from pyspark.sql import functions as sf
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
)


class DataFrameQueryContextTestsMixin:
def test_dataframe_query_context(self):
# SPARK-47274: Add more useful contexts for PySpark DataFrame API errors.
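        # With ANSI mode enabled, each invalid Column operation below should raise an
        # error whose query context is a DataFrameQueryContext and whose `fragment`
        # names the Python Column method (e.g. `__truediv__`) that created the failing
        # expression. `check_error` comes from the shared testing utilities and compares
        # the error class, message parameters, and query context against expectations.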
with self.sql_conf({"spark.sql.ansi.enabled": True}):
df = self.spark.range(10)
# DataFrameQueryContext with pysparkLoggingInfo - divide
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("div_zero", df.id / 0).collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
fragment="__truediv__",
)
# DataFrameQueryContext with pysparkLoggingInfo - plus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("plus_invalid_type", df.id + "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
)
# DataFrameQueryContext with pysparkLoggingInfo - minus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("minus_invalid_type", df.id - "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
)
# DataFrameQueryContext with pysparkLoggingInfo - multiply
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_invalid_type", df.id * "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
)
# DataFrameQueryContext with pysparkLoggingInfo - mod
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("mod_invalid_type", df.id % "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mod__",
)
# DataFrameQueryContext with pysparkLoggingInfo - equalTo
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("equalTo_invalid_type", df.id == "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__eq__",
)
# DataFrameQueryContext with pysparkLoggingInfo - lt
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("lt_invalid_type", df.id < "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__lt__",
)
# DataFrameQueryContext with pysparkLoggingInfo - leq
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("leq_invalid_type", df.id <= "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__le__",
)
# DataFrameQueryContext with pysparkLoggingInfo - geq
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("geq_invalid_type", df.id >= "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__ge__",
)
# DataFrameQueryContext with pysparkLoggingInfo - gt
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("gt_invalid_type", df.id > "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__gt__",
)
# DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("eqNullSafe_invalid_type", df.id.eqNullSafe("string")).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="eqNullSafe",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseOR_invalid_type", df.id.bitwiseOR("string")).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseOR",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseAND_invalid_type", df.id.bitwiseAND("string")).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseAND",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseXOR_invalid_type", df.id.bitwiseXOR("string")).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseXOR",
)
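            # Chained operations: when several withColumn calls are combined, the
            # reported fragment should identify the specific operation that failed,
            # not the first or last one in the chain.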
# DataFrameQueryContext with pysparkLoggingInfo - chained (`divide` is problematic)
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_zero", df.id / 0
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
fragment="__truediv__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_string", df.id + "string").withColumn(
"minus_ten", df.id - 10
).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn(
"minus_string", df.id - "string"
).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`multiply` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_string", df.id * "string").withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
)
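            # Multiple expressions in a single df.select: each column expression
            # carries its own context, so the error should point at the one
            # expression that actually failed.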
# Multiple expressions in df.select (`divide` is problematic)
with self.assertRaises(ArithmeticException) as pe:
df.select(df.id - 10, df.id + 4, df.id / 0, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
fragment="__truediv__",
)
# Multiple expressions in df.select (`plus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.select(df.id - 10, df.id + "string", df.id / 10, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
)
# Multiple expressions in df.select (`minus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.select(df.id - "string", df.id + 4, df.id / 10, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
)
# Multiple expressions in df.select (`multiply` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.select(df.id - 10, df.id + 4, df.id / 10, df.id * "string").collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
)
# Multiple expressions with pre-declared expressions (`divide` is problematic)
a = df.id / 10
b = df.id / 0
with self.assertRaises(ArithmeticException) as pe:
df.select(a, df.id + 4, b, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
fragment="__truediv__",
)
# Multiple expressions with pre-declared expressions (`plus` is problematic)
a = df.id + "string"
b = df.id + 4
with self.assertRaises(NumberFormatException) as pe:
df.select(df.id / 10, a, b, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
)
# Multiple expressions with pre-declared expressions (`minus` is problematic)
a = df.id - "string"
b = df.id - 5
with self.assertRaises(NumberFormatException) as pe:
df.select(a, df.id / 10, b, df.id * 5).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
)
# Multiple expressions with pre-declared expressions (`multiply` is problematic)
a = df.id * "string"
b = df.id * 10
with self.assertRaises(NumberFormatException) as pe:
df.select(a, df.id / 10, b, df.id + 5).collect()
self.check_error(
exception=pe.exception,
errorClass="CAST_INVALID_INPUT",
messageParameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
            )

    def test_sql_query_context(self):
with self.sql_conf({"spark.sql.ansi.enabled": True}):
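            # Errors raised while executing SQL text should carry a SQL query context
            # (QueryContextType.SQL); no `fragment` is checked here because the failing
            # expression lives in the SQL string rather than in Python code.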
# SQLQueryContext
with self.assertRaises(ArithmeticException) as pe:
self.spark.sql("select 10/0").collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.SQL,
)
            # No QueryContext: analysis errors raised while parsing or analyzing the
            # query (before any execution) carry no query context at all.
with self.assertRaises(AnalysisException) as pe:
self.spark.sql("select * from non-existing-table")
self.check_error(
exception=pe.exception,
errorClass="INVALID_IDENTIFIER",
messageParameters={"ident": "non-existing-table"},
query_context_type=None,
            )

    def test_query_context_complex(self):
with self.sql_conf({"spark.sql.ansi.enabled": True}):
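            # The context type should follow where the failing expression was written:
            # the same division by zero yields a SQL context when written in SQL text
            # and a DataFrame context when built through the Column API, even when it
            # is nested inside a larger expression.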
# SQLQueryContext
with self.assertRaises(ArithmeticException) as pe:
self.spark.sql("select (10/0)*100").collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.SQL,
)
# DataFrameQueryContext
df = self.spark.range(10)
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("div_zero", (df.id / 0) * 10).collect()
self.check_error(
exception=pe.exception,
errorClass="DIVIDE_BY_ZERO",
messageParameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
fragment="__truediv__",
            )

    def test_dataframe_query_context_col(self):
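        # An unresolved column created via sf.col should still carry a DataFrame query
        # context; the fragment is "col", pointing at the call that created the column.
        # No ANSI conf is needed here because the failure happens during analysis.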
with self.assertRaises(AnalysisException) as pe:
self.spark.range(1).select(sf.col("id") + sf.col("idd")).show()
self.check_error(
exception=pe.exception,
errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
messageParameters={"objectName": "`idd`", "proposal": "`id`"},
query_context_type=QueryContextType.DataFrame,
fragment="col",
        )


class DataFrameQueryContextTests(DataFrameQueryContextTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.test_dataframe_query_context import *  # noqa: F401

    # Prefer JUnit-style XML test reports when the optional xmlrunner package is installed.
    try:
        import xmlrunner  # type: ignore
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)