#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
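"""Basic Spark Connect tests that compare DataFrame behavior between a Spark Connect
session (``self.connect``) and a classic, JVM-backed PySpark session (``self.spark``)."""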
import os
import gc
import unittest
import shutil
import tempfile
import io
from contextlib import redirect_stdout
import datetime
from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.types import (
StructType,
StructField,
LongType,
StringType,
IntegerType,
MapType,
ArrayType,
Row,
)
from pyspark.testing.utils import eventually
from pyspark.testing.connectutils import (
should_test_connect,
connect_requirement_message,
ReusedMixedTestCase,
)
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
if should_test_connect:
from pyspark.sql.connect.proto import ExecutePlanResponse, Expression as ProtoExpression
from pyspark.sql.connect.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException
@unittest.skipIf(
not should_test_connect or is_remote_only(),
connect_requirement_message or "Requires JVM access",
)
class SparkConnectSQLTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils):
"""Parent test fixture class for all Spark Connect related
test cases."""
@classmethod
def setUpClass(cls):
super(SparkConnectSQLTestCase, cls).setUpClass()
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.testDataStr = [Row(key=str(i)) for i in range(100)]
cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
cls.tbl_name = "test_connect_basic_table_1"
cls.tbl_name2 = "test_connect_basic_table_2"
cls.tbl_name3 = "test_connect_basic_table_3"
cls.tbl_name4 = "test_connect_basic_table_4"
cls.tbl_name_empty = "test_connect_basic_table_empty"
# Cleanup test data
cls.spark_connect_clean_up_test_data()
# Load test data
cls.spark_connect_load_test_data()
@classmethod
def tearDownClass(cls):
try:
cls.spark_connect_clean_up_test_data()
finally:
super(SparkConnectSQLTestCase, cls).tearDownClass()
@classmethod
def spark_connect_load_test_data(cls):
df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
        # The test data is persisted as catalog tables (rather than temp views) so that it
        # remains visible to every Spark session created during the tests.
df.write.saveAsTable(cls.tbl_name)
df2 = cls.spark.createDataFrame(
[(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]
)
df2.write.saveAsTable(cls.tbl_name2)
df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"])
df3.write.saveAsTable(cls.tbl_name3)
df4 = cls.spark.createDataFrame(
[(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"]
)
df4.write.saveAsTable(cls.tbl_name4)
empty_table_schema = StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
)
emptyRDD = cls.spark.sparkContext.emptyRDD()
empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
empty_df.write.saveAsTable(cls.tbl_name_empty)
@classmethod
def spark_connect_clean_up_test_data(cls):
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name3))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name4))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
class SparkConnectBasicTests(SparkConnectSQLTestCase):
def test_serialization(self):
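        # Round-trip a Connect DataFrame through cloudpickle and verify the collected rows
        # are unchanged.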
from pyspark.cloudpickle import dumps, loads
cdf = self.connect.range(10)
data = dumps(cdf)
cdf2 = loads(data)
self.assertEqual(cdf.collect(), cdf2.collect())
def test_window_spec_serialization(self):
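        # A Connect WindowSpec should survive pickling with CPickleSerializer; the string
        # representations before and after must match.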
from pyspark.sql.connect.window import Window
from pyspark.serializers import CPickleSerializer
pickle_ser = CPickleSerializer()
w = Window.partitionBy("some_string").orderBy("value")
b = pickle_ser.dumps(w)
w2 = pickle_ser.loads(b)
self.assertEqual(str(w), str(w2))
def test_df_getattr_behavior(self):
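        # Attributes set dynamically on classic and Connect DataFrames should behave alike,
        # and hasattr() should report missing attributes as absent.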
cdf = self.connect.range(10)
sdf = self.spark.range(10)
sdf._simple_extension = 10
cdf._simple_extension = 10
self.assertEqual(sdf._simple_extension, cdf._simple_extension)
self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
self.assertTrue(hasattr(cdf, "_simple_extension"))
        self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exist"))
def test_df_get_item(self):
# SPARK-41779: test __getitem__
query = """
SELECT * FROM VALUES
(true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
AS tab(a, b, c)
"""
# +-----+----+----+
# | a| b| c|
# +-----+----+----+
# | true| 1|NULL|
# |false|NULL| 2.0|
# | NULL| 3| 3.0|
# +-----+----+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
# filter
self.assert_eq(
cdf[cdf.a].toPandas(),
sdf[sdf.a].toPandas(),
)
self.assert_eq(
cdf[cdf.b.isin(2, 3)].toPandas(),
sdf[sdf.b.isin(2, 3)].toPandas(),
)
self.assert_eq(
cdf[cdf.c > 1.5].toPandas(),
sdf[sdf.c > 1.5].toPandas(),
)
# select
self.assert_eq(
cdf[[cdf.a, "b", cdf.c]].toPandas(),
sdf[[sdf.a, "b", sdf.c]].toPandas(),
)
self.assert_eq(
cdf[(cdf.a, "b", cdf.c)].toPandas(),
sdf[(sdf.a, "b", sdf.c)].toPandas(),
)
# select by index
self.assertTrue(isinstance(cdf[0], Column))
self.assertTrue(isinstance(cdf[1], Column))
self.assertTrue(isinstance(cdf[2], Column))
self.assert_eq(
cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
)
# check error
with self.assertRaises(PySparkTypeError) as pe:
cdf[1.5]
self.check_error(
exception=pe.exception,
errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
messageParameters={
"arg_name": "item",
"arg_type": "float",
},
)
with self.assertRaises(PySparkTypeError) as pe:
cdf[None]
self.check_error(
exception=pe.exception,
errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
messageParameters={
"arg_name": "item",
"arg_type": "NoneType",
},
)
with self.assertRaises(PySparkTypeError) as pe:
cdf[cdf]
self.check_error(
exception=pe.exception,
errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
messageParameters={
"arg_name": "item",
"arg_type": "DataFrame",
},
)
def test_join_condition_column_list_columns(self):
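        # Inner joins with a single join condition and with a list of conditions must match
        # classic Spark.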
left_connect_df = self.connect.read.table(self.tbl_name)
right_connect_df = self.connect.read.table(self.tbl_name2)
left_spark_df = self.spark.read.table(self.tbl_name)
right_spark_df = self.spark.read.table(self.tbl_name2)
joined_plan = left_connect_df.join(
other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner"
)
joined_plan2 = left_spark_df.join(
other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner"
)
self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas())
joined_plan3 = left_connect_df.join(
other=right_connect_df,
on=[
left_connect_df.id == right_connect_df.col1,
left_connect_df.name == right_connect_df.col2,
],
how="inner",
)
joined_plan4 = left_spark_df.join(
other=right_spark_df,
on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2],
how="inner",
)
self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
def test_join_ambiguous_cols(self):
# SPARK-41812: test join with ambiguous columns
data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
cdf1 = self.connect.createDataFrame(data1)
sdf1 = self.spark.createDataFrame(data1)
data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
cdf2 = self.connect.createDataFrame(data2)
sdf2 = self.spark.createDataFrame(data2)
cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
self.assertEqual(cdf3.schema, sdf3.schema)
self.assertEqual(cdf3.collect(), sdf3.collect())
cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
self.assertEqual(cdf4.schema, sdf4.schema)
self.assertEqual(cdf4.collect(), sdf4.collect())
cdf5 = cdf1.join(
cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
)
sdf5 = sdf1.join(
sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
)
self.assertEqual(cdf5.schema, sdf5.schema)
self.assertEqual(cdf5.collect(), sdf5.collect())
cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
self.assertEqual(cdf6.schema, sdf6.schema)
self.assertEqual(cdf6.collect(), sdf6.collect())
cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
self.assertEqual(cdf7.schema, sdf7.schema)
self.assertEqual(cdf7.collect(), sdf7.collect())
def test_join_with_cte(self):
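        # Joining a range DataFrame against the result of a CTE query should match classic Spark.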
cte_query = "with dt as (select 1 as ida) select ida as id from dt"
sdf1 = self.spark.range(10)
sdf2 = self.spark.sql(cte_query)
sdf3 = sdf1.join(sdf2, sdf1.id == sdf2.id)
cdf1 = self.connect.range(10)
cdf2 = self.connect.sql(cte_query)
cdf3 = cdf1.join(cdf2, cdf1.id == cdf2.id)
self.assertEqual(sdf3.schema, cdf3.schema)
self.assertEqual(sdf3.collect(), cdf3.collect())
def test_with_columns_renamed(self):
# SPARK-41312: test DataFrame.withColumnsRenamed()
self.assertEqual(
self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema,
)
self.assertEqual(
self.connect.read.table(self.tbl_name)
.withColumnsRenamed({"id": "id_new", "name": "name_new"})
.schema,
self.spark.read.table(self.tbl_name)
.withColumnsRenamed({"id": "id_new", "name": "name_new"})
.schema,
)
def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
self.assertGreater(len(result), 0)
def _check_print_schema(self, query: str):
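        # Helper: printSchema() output, including output for levels -1 through 4, must be
        # identical between the classic and Connect sessions.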
with io.StringIO() as buf, redirect_stdout(buf):
self.spark.sql(query).printSchema()
print1 = buf.getvalue()
with io.StringIO() as buf, redirect_stdout(buf):
self.connect.sql(query).printSchema()
print2 = buf.getvalue()
self.assertEqual(print1, print2, query)
for level in [-1, 0, 1, 2, 3, 4]:
with io.StringIO() as buf, redirect_stdout(buf):
self.spark.sql(query).printSchema(level)
print1 = buf.getvalue()
with io.StringIO() as buf, redirect_stdout(buf):
self.connect.sql(query).printSchema(level)
print2 = buf.getvalue()
self.assertEqual(print1, print2, query)
def test_schema(self):
schema = self.connect.read.table(self.tbl_name).schema
self.assertEqual(
StructType(
[StructField("id", LongType(), True), StructField("name", StringType(), True)]
),
schema,
)
# test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
query = """
SELECT * FROM VALUES
(float(1.0), double(1.0), 1.0, "1", true, NULL),
(float(2.0), double(2.0), 2.0, "2", false, NULL),
(float(3.0), double(3.0), NULL, "3", false, NULL)
AS tab(a, b, c, d, e, f)
"""
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
# test TimestampType, DateType
query = """
SELECT * FROM VALUES
(TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
(TIMESTAMP('2019-04-12 15:50:00'), NULL),
(NULL, DATE('2022-02-22'))
AS tab(a, b)
"""
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
# test DayTimeIntervalType
query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
# test MapType
query = """
SELECT * FROM VALUES
(MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
(MAP('x', 'yz'), MAP('x', NULL), NULL),
(MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
AS tab(a, b, c)
"""
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
# test ArrayType
query = """
SELECT * FROM VALUES
(ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
(ARRAY('x', NULL), NULL, ARRAY(1, 3)),
(NULL, ARRAY(-1, -2, -3), Array())
AS tab(a, b, c)
"""
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
# test StructType
query = """
SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
(float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
(float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
(float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
AS tab(a, b, c, d, e, f, g, h)
"""
self.assertEqual(
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
self._check_print_schema(query)
def test_to(self):
# SPARK-41464: test DataFrame.to()
cdf = self.connect.read.table(self.tbl_name)
df = self.spark.read.table(self.tbl_name)
def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType):
cdf_to = cdf.to(schema)
df_to = df.to(schema)
self.assertEqual(cdf_to.schema, df_to.schema)
self.assert_eq(cdf_to.toPandas(), df_to.toPandas())
# The schema has not changed
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField("name", StringType(), True),
]
)
assert_eq_schema(cdf, df, schema)
# Change schema with struct
schema2 = StructType([StructField("struct", schema, False)])
cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2)
df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2)
self.assertEqual(cdf_to.schema, df_to.schema)
# Change the column name
schema = StructType(
[
StructField("col1", IntegerType(), True),
StructField("col2", StringType(), True),
]
)
assert_eq_schema(cdf, df, schema)
# Change the column data type
schema = StructType(
[
StructField("id", StringType(), True),
StructField("name", StringType(), True),
]
)
assert_eq_schema(cdf, df, schema)
        # Reduce the number of columns and change the data type
schema = StructType(
[
StructField("id", LongType(), True),
]
)
assert_eq_schema(cdf, df, schema)
# incompatible field nullability
schema = StructType([StructField("id", LongType(), False)])
self.assertRaisesRegex(
AnalysisException,
"NULLABLE_COLUMN_OR_FIELD",
lambda: cdf.to(schema).toPandas(),
)
        # field cannot be upcast
schema = StructType([StructField("name", LongType())])
self.assertRaisesRegex(
AnalysisException,
"INVALID_COLUMN_OR_FIELD_DATA_TYPE",
lambda: cdf.to(schema).toPandas(),
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField("name", IntegerType(), True),
]
)
self.assertRaisesRegex(
AnalysisException,
"INVALID_COLUMN_OR_FIELD_DATA_TYPE",
lambda: cdf.to(schema).toPandas(),
)
# Test map type and array type
schema = StructType(
[
StructField("id", StringType(), True),
StructField("my_map", MapType(StringType(), IntegerType(), False), True),
StructField("my_array", ArrayType(IntegerType(), False), True),
]
)
cdf = self.connect.read.table(self.tbl_name4)
df = self.spark.read.table(self.tbl_name4)
assert_eq_schema(cdf, df, schema)
def test_toDF(self):
# SPARK-41310: test DataFrame.toDF()
self.assertEqual(
self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema,
self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema,
)
def test_print_schema(self):
# SPARK-41216: Test print schema
tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y").schema.treeString()
# root
# |-- X: integer (nullable = false)
# |-- Y: integer (nullable = false)
expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
self.assertEqual(tree_str, expected)
def test_is_local(self):
# SPARK-41216: Test is local
self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal())
self.assertFalse(self.connect.read.table(self.tbl_name).isLocal())
def test_is_streaming(self):
# SPARK-41216: Test is streaming
self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
def test_input_files(self):
# SPARK-41216: Test input files
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
try:
self.df_text.write.text(tmpPath)
input_files_list1 = (
self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
)
input_files_list2 = (
self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
)
self.assertTrue(len(input_files_list1) > 0)
self.assertEqual(len(input_files_list1), len(input_files_list2))
for file_path in input_files_list2:
self.assertTrue(file_path in input_files_list1)
finally:
shutil.rmtree(tmpPath)
def test_limit_offset(self):
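        # The test table has 100 rows: limit(10).offset(1) keeps 9 rows, and
        # offset(98).limit(10) keeps the last 2.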
df = self.connect.read.table(self.tbl_name)
pd = df.limit(10).offset(1).toPandas()
self.assertEqual(9, len(pd.index))
pd2 = df.offset(98).limit(10).toPandas()
self.assertEqual(2, len(pd2.index))
def test_tail(self):
df = self.connect.read.table(self.tbl_name)
df2 = self.spark.read.table(self.tbl_name)
self.assertEqual(df.tail(10), df2.tail(10))
def test_sql(self):
pdf = self.connect.sql("SELECT 1").toPandas()
self.assertEqual(1, len(pdf.index))
def test_sql_with_named_args(self):
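        # SQL with named parameters (:m and :minId), including a Column argument built
        # with create_map.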
sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId"
df = self.connect.sql(
sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"), CF.lit(1))}
)
df2 = self.spark.sql(sqlText, args={"minId": 7, "m": SF.create_map(SF.lit("a"), SF.lit(1))})
self.assert_eq(df.toPandas(), df2.toPandas())
def test_namedargs_with_global_limit(self):
sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val)
where val = :val"""
df = self.connect.sql(sqlText, args={"val": 1})
df2 = self.spark.sql(sqlText, args={"val": 1})
self.assert_eq(df.toPandas(), df2.toPandas())
self.assert_eq(df.first()[0], datetime.datetime(2022, 12, 25, 10, 30))
self.assert_eq(df.first().date, datetime.datetime(2022, 12, 25, 10, 30))
self.assert_eq(df.first()[1], 1)
self.assert_eq(df.first().val, 1)
def test_sql_with_pos_args(self):
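        # SQL with positional parameters (?), including a Column argument built with array.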
sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?"
df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7])
df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7])
self.assert_eq(df.toPandas(), df2.toPandas())
def test_sql_with_invalid_args(self):
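        # Passing a set as `args` must raise INVALID_TYPE on both the classic and Connect
        # sessions.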
sqlText = "SELECT ?, ?, ?"
for session in [self.connect, self.spark]:
with self.assertRaises(PySparkTypeError) as pe:
session.sql(sqlText, args={1, 2, 3})
self.check_error(
exception=pe.exception,
errorClass="INVALID_TYPE",
messageParameters={"arg_name": "args", "arg_type": "set"},
)
def test_deduplicate(self):
# SPARK-41326: test distinct and dropDuplicates.
df = self.connect.read.table(self.tbl_name)
df2 = self.spark.read.table(self.tbl_name)
self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas())
self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas())
self.assert_eq(
df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()
)
def test_drop(self):
# SPARK-41169: test drop
query = """
SELECT * FROM VALUES
(false, 1, NULL), (false, NULL, 2), (NULL, 3, 3)
AS tab(a, b, c)
"""
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.drop("a").toPandas(),
sdf.drop("a").toPandas(),
)
self.assert_eq(
cdf.drop("a", "b").toPandas(),
sdf.drop("a", "b").toPandas(),
)
self.assert_eq(
cdf.drop("a", "x").toPandas(),
sdf.drop("a", "x").toPandas(),
)
self.assert_eq(
cdf.drop(cdf.a, "x").toPandas(),
sdf.drop(sdf.a, "x").toPandas(),
)
def test_subquery_alias(self) -> None:
# SPARK-40938: test subquery alias.
plan_text = (
self.connect.read.table(self.tbl_name)
.alias("special_alias")
._explain_string(extended=True)
)
self.assertTrue("special_alias" in plan_text)
def test_sort(self):
# SPARK-41332: test sort
query = """
SELECT * FROM VALUES
(false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
AS tab(a, b, c)
"""
# +-----+----+----+
# | a| b| c|
# +-----+----+----+
# |false| 1|NULL|
# |false|NULL| 2.0|
# | NULL| 3| 3.0|
# +-----+----+----+
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
self.assert_eq(
cdf.sort("a").toPandas(),
sdf.sort("a").toPandas(),
)
self.assert_eq(
cdf.sort("c").toPandas(),
sdf.sort("c").toPandas(),
)
self.assert_eq(
cdf.sort("b").toPandas(),
sdf.sort("b").toPandas(),
)
self.assert_eq(
cdf.sort(cdf.c, "b").toPandas(),
sdf.sort(sdf.c, "b").toPandas(),
)
self.assert_eq(
cdf.sort(cdf.c.desc(), "b").toPandas(),
sdf.sort(sdf.c.desc(), "b").toPandas(),
)
self.assert_eq(
cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(),
sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(),
)
def test_range(self):
self.assert_eq(
self.connect.range(start=0, end=10).toPandas(),
self.spark.range(start=0, end=10).toPandas(),
)
self.assert_eq(
self.connect.range(start=0, end=10, step=3).toPandas(),
self.spark.range(start=0, end=10, step=3).toPandas(),
)
self.assert_eq(
self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(),
)
# SPARK-41301
self.assert_eq(
self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()
)
def test_create_global_temp_view(self):
# SPARK-41127: test global temp view creation.
with self.tempView("view_1"):
self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1")
self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
            # Creating a global temp view that already exists should fail.
self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
with self.assertRaises(AnalysisException):
self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
def test_create_session_local_temp_view(self):
# SPARK-41372: test session local temp view creation.
with self.tempView("view_local_temp"):
self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp")
self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1)
self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp")
self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)
            # Creating a temp view that already exists should fail.
with self.assertRaises(AnalysisException):
self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")
def test_select_expr(self):
# SPARK-41201: test selectExpr API.
self.assert_eq(
self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name)
.selectExpr(["id * 2", "cast(name as long) as name"])
.toPandas(),
self.spark.read.table(self.tbl_name)
.selectExpr(["id * 2", "cast(name as long) as name"])
.toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name)
.selectExpr("id * 2", "cast(name as long) as name")
.toPandas(),
self.spark.read.table(self.tbl_name)
.selectExpr("id * 2", "cast(name as long) as name")
.toPandas(),
)
def test_select_star(self):
data = [Row(a=1, b=Row(c=2, d=Row(e=3)))]
# +---+--------+
# | a| b|
# +---+--------+
# | 1|{2, {3}}|
# +---+--------+
cdf = self.connect.createDataFrame(data=data)
sdf = self.spark.createDataFrame(data=data)
self.assertEqual(
cdf.select("*").collect(),
sdf.select("*").collect(),
)
self.assertEqual(
cdf.select("a", "*").collect(),
sdf.select("a", "*").collect(),
)
self.assertEqual(
cdf.select("a", "b").collect(),
sdf.select("a", "b").collect(),
)
self.assertEqual(
cdf.select("a", "b.*").collect(),
sdf.select("a", "b.*").collect(),
)
def test_union_by_name(self):
# SPARK-41832: Test unionByName
data1 = [(1, 2, 3)]
data2 = [(6, 2, 5)]
df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"])
df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"])
union_df_connect = df1_connect.unionByName(df2_connect)
df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"])
df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"])
union_df_spark = df1_spark.unionByName(df2_spark)
self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"])
union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True)
df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"])
union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True)
self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas())
def test_observe(self):
# SPARK-41527: test DataFrame.observe()
observation_name = "my_metric"
self.assert_eq(
self.connect.read.table(self.tbl_name)
.filter("id > 3")
.observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
.toPandas(),
self.spark.read.table(self.tbl_name)
.filter("id > 3")
.observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
.toPandas(),
)
from pyspark.sql.connect.observation import Observation as ConnectObservation
from pyspark.sql.observation import Observation
cobservation = ConnectObservation(observation_name)
observation = Observation(observation_name)
cdf = (
self.connect.read.table(self.tbl_name)
.filter("id > 3")
.observe(cobservation, CF.min("id"), CF.max("id"), CF.sum("id"))
.toPandas()
)
df = (
self.spark.read.table(self.tbl_name)
.filter("id > 3")
.observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
.toPandas()
)
self.assert_eq(cdf, df)
self.assertEqual(cobservation.get, observation.get)
observed_metrics = cdf.attrs["observed_metrics"]
self.assert_eq(len(observed_metrics), 1)
self.assert_eq(observed_metrics[0].name, observation_name)
self.assert_eq(len(observed_metrics[0].metrics), 3)
for metric in observed_metrics[0].metrics:
self.assertIsInstance(metric, ProtoExpression.Literal)
values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
self.assert_eq(values, [4, 99, 4944])
with self.assertRaises(PySparkValueError) as pe:
self.connect.read.table(self.tbl_name).observe(observation_name)
self.check_error(
exception=pe.exception,
errorClass="CANNOT_BE_EMPTY",
messageParameters={"item": "exprs"},
)
with self.assertRaises(PySparkTypeError) as pe:
self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
self.check_error(
exception=pe.exception,
errorClass="NOT_LIST_OF_COLUMN",
messageParameters={"arg_name": "exprs"},
)
def test_with_columns(self):
# SPARK-41256: test withColumn(s).
self.assert_eq(
self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(),
self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name)
.withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)})
.toPandas(),
self.spark.read.table(self.tbl_name)
.withColumns(
{
"id": SF.lit(False),
"col_not_exist": SF.lit(False),
}
)
.toPandas(),
)
def test_hint(self):
# SPARK-41349: Test hint
self.assert_eq(
self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(),
)
        # A hint with an unsupported name is ignored
self.assert_eq(
self.connect.read.table(self.tbl_name).hint("illegal").toPandas(),
self.spark.read.table(self.tbl_name).hint("illegal").toPandas(),
)
# Hint with all supported parameter values
such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
self.assert_eq(
self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(),
)
# Hint with unsupported parameter values
with self.assertRaises(AnalysisException):
self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
# Hint with unsupported parameter types
with self.assertRaises(TypeError):
self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas()
# Hint with unsupported parameter types
with self.assertRaises(TypeError):
self.connect.read.table(self.tbl_name).hint(
"my awesome hint", 1.2345, 2, such_a_nice_list, range(6)
).toPandas()
# Hint with wrong combination
with self.assertRaises(AnalysisException):
self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas()
def test_join_hint(self):
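        # Join hints (BROADCAST, MERGE, SHUFFLE_HASH) should be reflected in the explain output.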
cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
cdf2 = self.connect.createDataFrame(
[Row(height=80, name="Tom"), Row(height=85, name="Bob")]
)
self.assertTrue(
"BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()
)
self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string())
self.assertTrue(
"ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()
)
def test_extended_hint_types(self):
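        # Hint parameters may be str, float, int, Column, or lists of those; a dict parameter
        # is rejected with INVALID_ITEM_FOR_CONTAINER.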
cdf = self.connect.range(100).toDF("id")
cdf.hint(
"my awesome hint",
1.2345,
"what",
["itworks1", "itworks2", "itworks3"],
).show()
with self.assertRaises(PySparkTypeError) as pe:
cdf.hint(
"my awesome hint",
1.2345,
"what",
{"itworks1": "itworks2"},
).show()
self.check_error(
exception=pe.exception,
errorClass="INVALID_ITEM_FOR_CONTAINER",
messageParameters={
"arg_name": "parameters",
"allowed_types": "str, float, int, Column, list[str], list[float], list[int]",
"item_type": "dict",
},
)
def test_empty_dataset(self):
# SPARK-41005: Test arrow based collection with empty dataset.
self.assertTrue(
self.connect.sql("SELECT 1 AS X LIMIT 0")
.toPandas()
.equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas())
)
pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas()
self.assertEqual(0, len(pdf)) # empty dataset
self.assertEqual(1, len(pdf.columns)) # one column
self.assertEqual("X", pdf.columns[0])
def test_is_empty(self):
# SPARK-41212: Test is empty
self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty())
self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty())
def test_is_empty_with_unsupported_types(self):
df = self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval")
self.assertEqual(df.count(), 1)
self.assertFalse(df.isEmpty())
def test_session(self):
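        # The sparkSession of a DataFrame created from the Connect session must be that
        # same session.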
self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession)
def test_show(self):
# SPARK-41111: Test the show method
show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string()
# +---+---+
# | X| Y|
# +---+---+
# | 1| 2|
# +---+---+
expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n"
self.assertEqual(show_str, expected)
def test_repr(self):
# SPARK-41213: Test the __repr__ method
query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
self.assertEqual(
self.connect.sql(query).__repr__(),
self.spark.sql(query).__repr__(),
)
def test_explain_string(self):
# SPARK-41122: test explain API.
plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True)
self.assertTrue("Parsed Logical Plan" in plan_str)
self.assertTrue("Analyzed Logical Plan" in plan_str)
self.assertTrue("Optimized Logical Plan" in plan_str)
self.assertTrue("Physical Plan" in plan_str)
with self.assertRaises(PySparkValueError) as pe:
self.connect.sql("SELECT 1")._explain_string(mode="unknown")
self.check_error(
exception=pe.exception,
errorClass="UNKNOWN_EXPLAIN_MODE",
messageParameters={"explain_mode": "unknown"},
)
def test_count(self) -> None:
# SPARK-41308: test count() API.
self.assertEqual(
self.connect.read.table(self.tbl_name).count(),
self.spark.read.table(self.tbl_name).count(),
)
def test_simple_transform(self) -> None:
"""SPARK-41203: Support DF.transform"""
def transform_df(input_df: CDataFrame) -> CDataFrame:
return input_df.select((CF.col("id") + CF.lit(10)).alias("id"))
df = self.connect.range(1, 100)
result_left = df.transform(transform_df).collect()
result_right = self.connect.range(11, 110).collect()
self.assertEqual(result_right, result_left)
        # A transform function that does not return a DataFrame raises an AssertionError.
with self.assertRaises(AssertionError):
df.transform(lambda x: 2) # type: ignore
def test_alias(self) -> None:
"""Testing supported and unsupported alias"""
col0 = (
self.connect.range(1, 10)
.select(CF.col("id").alias("name", metadata={"max": 99}))
.schema.names[0]
)
self.assertEqual("name", col0)
with self.assertRaises(SparkConnectException) as exc:
self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect()
self.assertIn("(this, is, not)", str(exc.exception))
def test_column_regexp(self) -> None:
# SPARK-41438: test dataframe.colRegex()
ndf = self.connect.read.table(self.tbl_name3)
df = self.spark.read.table(self.tbl_name3)
self.assert_eq(
ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(),
df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(),
)
def test_repartition(self) -> None:
# SPARK-41354: test dataframe.repartition(numPartitions)
self.assert_eq(
self.connect.read.table(self.tbl_name).repartition(10).toPandas(),
self.spark.read.table(self.tbl_name).repartition(10).toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name).coalesce(10).toPandas(),
self.spark.read.table(self.tbl_name).coalesce(10).toPandas(),
)
def test_repartition_by_expression(self) -> None:
# SPARK-41354: test dataframe.repartition(expressions)
self.assert_eq(
self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(),
self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name).repartition("id").toPandas(),
self.spark.read.table(self.tbl_name).repartition("id").toPandas(),
)
# repartition with unsupported parameter values
with self.assertRaises(AnalysisException):
self.connect.read.table(self.tbl_name).repartition("id+1").toPandas()
def test_repartition_by_range(self) -> None:
# SPARK-41354: test dataframe.repartitionByRange(expressions)
cdf = self.connect.read.table(self.tbl_name)
sdf = self.spark.read.table(self.tbl_name)
self.assert_eq(
cdf.repartitionByRange(10, "id").toPandas(),
sdf.repartitionByRange(10, "id").toPandas(),
)
self.assert_eq(
cdf.repartitionByRange("id").toPandas(),
sdf.repartitionByRange("id").toPandas(),
)
self.assert_eq(
cdf.repartitionByRange(cdf.id.desc()).toPandas(),
sdf.repartitionByRange(sdf.id.desc()).toPandas(),
)
# repartitionByRange with unsupported parameter values
with self.assertRaises(AnalysisException):
self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas()
def test_crossjoin(self):
# SPARK-41227: Test CrossJoin
connect_df = self.connect.read.table(self.tbl_name)
spark_df = self.spark.read.table(self.tbl_name)
self.assert_eq(
set(
connect_df.select("id")
.join(other=connect_df.select("name"), how="cross")
.toPandas()
),
set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()),
)
self.assert_eq(
set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()),
set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()),
)
def test_self_join(self):
        # SPARK-47713: this query fails in classic Spark
df1 = self.connect.createDataFrame([(1, "a")], schema=["i", "j"])
df1_filter = df1.filter(df1.i > 0)
df2 = df1.join(df1_filter, df1.i == 1)
self.assertEqual(df2.count(), 1)
self.assertEqual(df2.columns, ["i", "j", "i", "j"])
self.assertEqual(list(df2.first()), [1, "a", 1, "a"])
def test_with_metadata(self):
cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
self.assertEqual(cdf.schema["age"].metadata, {})
self.assertEqual(cdf.schema["name"].metadata, {})
cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
with self.assertRaises(PySparkTypeError) as pe:
cdf.withMetadata(columnName="name", metadata=["magic"])
self.check_error(
exception=pe.exception,
errorClass="NOT_DICT",
messageParameters={
"arg_name": "metadata",
"arg_type": "list",
},
)
def test_version(self):
self.assertEqual(
self.connect.version,
self.spark.version,
)
def test_same_semantics(self):
plan = self.connect.sql("SELECT 1")
other = self.connect.sql("SELECT 1")
self.assertTrue(plan.sameSemantics(other))
def test_semantic_hash(self):
plan = self.connect.sql("SELECT 1")
other = self.connect.sql("SELECT 1")
self.assertEqual(
plan.semanticHash(),
other.semanticHash(),
)
def test_sql_with_command(self):
# SPARK-42705: spark.sql should return values from the command.
self.assertEqual(
self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
)
    def test_df_cache(self):
df = self.connect.range(10)
df.cache()
self.assert_eq(10, df.count())
self.assertTrue(df.is_cached)
def test_parse_col_name(self):
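        # parse_attr_name splits a (possibly backtick-quoted) multipart column name into its
        # parts and returns None when the quoting is malformed.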
from pyspark.sql.connect.types import parse_attr_name
self.assert_eq(parse_attr_name(""), [""])
self.assert_eq(parse_attr_name("a"), ["a"])
self.assert_eq(parse_attr_name("`a`"), ["a"])
self.assert_eq(parse_attr_name("`a"), None)
self.assert_eq(parse_attr_name("a`"), None)
self.assert_eq(parse_attr_name("`a`.b"), ["a", "b"])
self.assert_eq(parse_attr_name("`a`.`b`"), ["a", "b"])
self.assert_eq(parse_attr_name("`a```.b"), ["a`", "b"])
self.assert_eq(parse_attr_name("`a``.b"), None)
self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"])
self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"])
self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"])
self.assert_eq(parse_attr_name("`a.b.c`"), ["a.b.c"])
self.assert_eq(parse_attr_name("a.`b.c`"), ["a", "b.c"])
self.assert_eq(parse_attr_name("`a.b`.c"), ["a.b", "c"])
self.assert_eq(parse_attr_name("`a.b.c"), None)
self.assert_eq(parse_attr_name("a.b.c`"), None)
self.assert_eq(parse_attr_name("`a.`b.`c"), None)
self.assert_eq(parse_attr_name("a`.b`.c`"), None)
self.assert_eq(parse_attr_name("`ab..c`e.f"), None)
def test_verify_col_name(self):
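        # verify_col_name checks whether a (possibly nested and backtick-quoted) column name
        # resolves against the given schema.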
from pyspark.sql.connect.types import verify_col_name
cdf = (
self.connect.range(10)
.withColumn("v", CF.lit(123))
.withColumn("s", CF.struct("id", "v"))
.withColumn("m", CF.struct("s", "v"))
.withColumn("a", CF.array("s"))
)
# root
# |-- id: long (nullable = false)
# |-- v: integer (nullable = false)
# |-- s: struct (nullable = false)
# | |-- id: long (nullable = false)
# | |-- v: integer (nullable = false)
# |-- m: struct (nullable = false)
# | |-- s: struct (nullable = false)
# | | |-- id: long (nullable = false)
# | | |-- v: integer (nullable = false)
# | |-- v: integer (nullable = false)
# |-- a: array (nullable = false)
# | |-- element: struct (containsNull = false)
# | | |-- id: long (nullable = false)
# | | |-- v: integer (nullable = false)
self.assertTrue(verify_col_name("id", cdf.schema))
self.assertTrue(verify_col_name("`id`", cdf.schema))
self.assertTrue(verify_col_name("v", cdf.schema))
self.assertTrue(verify_col_name("`v`", cdf.schema))
self.assertFalse(verify_col_name("x", cdf.schema))
self.assertFalse(verify_col_name("`x`", cdf.schema))
self.assertTrue(verify_col_name("s", cdf.schema))
self.assertTrue(verify_col_name("`s`", cdf.schema))
self.assertTrue(verify_col_name("s.id", cdf.schema))
self.assertTrue(verify_col_name("s.`id`", cdf.schema))
self.assertTrue(verify_col_name("`s`.id", cdf.schema))
self.assertTrue(verify_col_name("`s`.`id`", cdf.schema))
self.assertFalse(verify_col_name("`s.id`", cdf.schema))
self.assertTrue(verify_col_name("m", cdf.schema))
self.assertTrue(verify_col_name("`m`", cdf.schema))
self.assertTrue(verify_col_name("m.s.id", cdf.schema))
self.assertTrue(verify_col_name("m.s.`id`", cdf.schema))
self.assertTrue(verify_col_name("m.`s`.id", cdf.schema))
self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema))
self.assertFalse(verify_col_name("m.`s.id`", cdf.schema))
self.assertTrue(verify_col_name("a", cdf.schema))
self.assertTrue(verify_col_name("`a`", cdf.schema))
self.assertTrue(verify_col_name("a.`v`", cdf.schema))
self.assertTrue(verify_col_name("a.`v`", cdf.schema))
self.assertTrue(verify_col_name("`a`.v", cdf.schema))
self.assertTrue(verify_col_name("`a`.`v`", cdf.schema))
self.assertFalse(verify_col_name("`a`.`x`", cdf.schema))
cdf = (
self.connect.range(10)
.withColumn("v", CF.lit(123))
.withColumn("s.s", CF.struct("id", "v"))
.withColumn("m`", CF.struct("`s.s`", "v"))
)
# root
# |-- id: long (nullable = false)
# |-- v: string (nullable = false)
# |-- s.s: struct (nullable = false)
# | |-- id: long (nullable = false)
# | |-- v: string (nullable = false)
# |-- m`: struct (nullable = false)
# | |-- s.s: struct (nullable = false)
# | | |-- id: long (nullable = false)
# | | |-- v: string (nullable = false)
# | |-- v: string (nullable = false)
self.assertFalse(verify_col_name("s", cdf.schema))
self.assertFalse(verify_col_name("`s`", cdf.schema))
self.assertFalse(verify_col_name("s.s", cdf.schema))
self.assertFalse(verify_col_name("s.`s`", cdf.schema))
self.assertFalse(verify_col_name("`s`.s", cdf.schema))
self.assertTrue(verify_col_name("`s.s`", cdf.schema))
self.assertFalse(verify_col_name("m", cdf.schema))
self.assertFalse(verify_col_name("`m`", cdf.schema))
self.assertTrue(verify_col_name("`m```", cdf.schema))
self.assertFalse(verify_col_name("`m```.s", cdf.schema))
self.assertFalse(verify_col_name("`m```.`s`", cdf.schema))
self.assertFalse(verify_col_name("`m```.s.s", cdf.schema))
self.assertFalse(verify_col_name("`m```.s.`s`", cdf.schema))
self.assertTrue(verify_col_name("`m```.`s.s`", cdf.schema))
self.assertFalse(verify_col_name("`m```.s.s.v", cdf.schema))
self.assertFalse(verify_col_name("`m```.s.`s`.v", cdf.schema))
self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema))
self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema))
def test_truncate_message(self):
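        # Rendering a proto plan with truncation enabled must yield a much shorter string
        # than the full rendering.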
cdf1 = self.connect.createDataFrame(
[
("a B c"),
("X y Z"),
],
["a" * 4096],
)
plan1 = cdf1._plan.to_proto(self.connect._client)
proto_string_1 = self.connect._client._proto_to_string(plan1, False)
self.assertTrue(len(proto_string_1) > 10000, len(proto_string_1))
proto_string_truncated_1 = self.connect._client._proto_to_string(plan1, True)
self.assertTrue(len(proto_string_truncated_1) < 4000, len(proto_string_truncated_1))
cdf2 = cdf1.select("a" * 4096, "a" * 4096, "a" * 4096)
plan2 = cdf2._plan.to_proto(self.connect._client)
proto_string_2 = self.connect._client._proto_to_string(plan2, False)
self.assertTrue(len(proto_string_2) > 20000, len(proto_string_2))
proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True)
self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2))
cdf3 = cdf1.select("a" * 4096)
for _ in range(64):
cdf3 = cdf3.select("a" * 4096)
plan3 = cdf3._plan.to_proto(self.connect._client)
proto_string_3 = self.connect._client._proto_to_string(plan3, False)
self.assertTrue(len(proto_string_3) > 128000, len(proto_string_3))
proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True)
self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3))
class SparkConnectGCTests(SparkConnectSQLTestCase):
@classmethod
def setUpClass(cls):
cls.origin = os.getenv("USER", None)
os.environ["USER"] = "SparkConnectGCTests"
super(SparkConnectGCTests, cls).setUpClass()
@classmethod
def tearDownClass(cls):
super(SparkConnectGCTests, cls).tearDownClass()
if cls.origin is not None:
os.environ["USER"] = cls.origin
else:
del os.environ["USER"]
def test_garbage_collection_checkpoint(self):
        # SPARK-48258: Make sure garbage-collecting a DataFrame removes the paired state
        # on the Spark Connect server
df = self.connect.range(10).localCheckpoint()
self.assertIsNotNone(df._plan._relation_id)
cached_remote_relation_id = df._plan._relation_id
jvm = self.spark._jvm
session_holder = getattr(
getattr(
jvm.org.apache.spark.sql.connect.service,
"SparkConnectService$",
),
"MODULE$",
).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)
# Check the state exists.
self.assertIsNotNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)
del df
gc.collect()
def condition():
            # Check that the state was removed upon garbage collection.
self.assertIsNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)
eventually(catch_assertions=True)(condition)()
def test_garbage_collection_derived_checkpoint(self):
# SPARK-48258: Should keep the cached remote relation when derived DataFrames exist
df = self.connect.range(10).localCheckpoint()
self.assertIsNotNone(df._plan._relation_id)
derived = df.repartition(10)
cached_remote_relation_id = df._plan._relation_id
jvm = self.spark._jvm
session_holder = getattr(
getattr(
jvm.org.apache.spark.sql.connect.service,
"SparkConnectService$",
),
"MODULE$",
).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)
# Check the state exists.
self.assertIsNotNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)
del df
gc.collect()
def condition():
self.assertIsNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)
# Should not remove the cache
with self.assertRaises(AssertionError):
eventually(catch_assertions=True, timeout=5)(condition)()
del derived
gc.collect()
eventually(catch_assertions=True)(condition)()
def test_arrow_batch_result_chunking(self):
# Two cases are tested here:
# (a) client preferred chunk size is set: the server should respect it
# (b) client preferred chunk size is not set: the server should use its own max chunk size
for preferred_chunk_size_optional, max_chunk_size_optional in ((1024, None), (None, 1024)):
sql_query = "select id, CAST(id + 0.5 AS DOUBLE) from range(0, 2000, 1, 4)"
cdf = self.connect.sql(sql_query)
sdf = self.spark.sql(sql_query)
original_verify_response_integrity = self.connect._client._verify_response_integrity
captured_chunks = []
def patched_verify_response_integrity(response):
original_verify_response_integrity(response)
if isinstance(response, ExecutePlanResponse) and response.HasField("arrow_batch"):
captured_chunks.append(response.arrow_batch)
try:
# Patch the response verifier for testing to access the chunked arrow batch
# responses.
self.connect._client._verify_response_integrity = patched_verify_response_integrity
# Override the chunk size to 1024 bytes for testing
if preferred_chunk_size_optional:
self.connect._client._preferred_arrow_chunk_size = preferred_chunk_size_optional
if max_chunk_size_optional:
self.connect.conf.set(
"spark.connect.session.resultChunking.maxChunkSize",
max_chunk_size_optional,
)
# Execute the query, and assert the results are correct.
self.assertEqual(cdf.collect(), sdf.collect())
# Verify the metadata of arrow batch chunks.
def split_into_batches(chunks):
batches = []
i = 0
n = len(chunks)
while i < n:
num_chunks = chunks[i].num_chunks_in_batch
batch = chunks[i : i + num_chunks]
batches.append(batch)
i += num_chunks
return batches
batches = split_into_batches(captured_chunks)
# There are 4 batches (partitions) in total.
self.assertEqual(len(batches), 4)
for batch in batches:
# In this example, the max chunk size is set to a small value, so each Arrow
# batch should be split into multiple chunks.
self.assertTrue(len(batch) > 5)
row_count = batch[0].row_count
row_start_offset = batch[0].start_offset
for i, chunk in enumerate(batch):
self.assertEqual(chunk.chunk_index, i)
self.assertEqual(chunk.num_chunks_in_batch, len(batch))
self.assertEqual(chunk.row_count, row_count)
self.assertEqual(chunk.start_offset, row_start_offset)
self.assertTrue(len(chunk.data) > 0)
self.assertTrue(
len(chunk.data)
<= (preferred_chunk_size_optional or max_chunk_size_optional)
)
finally:
self.connect._client._verify_response_integrity = original_verify_response_integrity
self.connect._client._preferred_arrow_chunk_size = None
self.connect.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)