#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils

if should_test_connect:
    from pyspark.sql import functions as SF
    from pyspark.sql.connect import functions as CF


class SparkConnectCollectionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
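    """End-to-end tests for collection APIs (collect, head, first, take, toPandas),
    comparing Spark Connect results against the classic Spark session."""
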
    def test_collect(self):
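        """Collect rows over Spark Connect and check they match the classic Spark session."""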
        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
        cdf = self.connect.sql(query)
        sdf = self.spark.sql(query)

        data = cdf.limit(10).collect()
        self.assertEqual(len(data), 10)

        # Check Row has schema column names.
        self.assertTrue("name" in data[0])
        self.assertTrue("id" in data[0])

        cdf = cdf.select(
            CF.log("id"), CF.log("id"), CF.struct("id", "name"), CF.struct("id", "name")
        ).limit(10)
        sdf = sdf.select(
            SF.log("id"), SF.log("id"), SF.struct("id", "name"), SF.struct("id", "name")
        ).limit(10)

        self.assertEqual(
            cdf.collect(),
            sdf.collect(),
        )

    def test_collect_timestamp(self):
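        """Collect timestamp values and compare results between Connect and classic Spark."""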
        query = """
            SELECT * FROM VALUES
            (TIMESTAMP('2022-12-25 10:30:00'), 1),
            (TIMESTAMP('2022-12-25 10:31:00'), 2),
            (TIMESTAMP('2022-12-25 10:32:00'), 1),
            (TIMESTAMP('2022-12-25 10:33:00'), 2),
            (TIMESTAMP('2022-12-26 09:30:00'), 1),
            (TIMESTAMP('2022-12-26 09:35:00'), 3)
            AS tab(date, val)
            """
        cdf = self.connect.sql(query)
        sdf = self.spark.sql(query)

        self.assertEqual(cdf.schema, sdf.schema)
        self.assertEqual(cdf.collect(), sdf.collect())

        self.assertEqual(
            cdf.select(CF.date_trunc("year", cdf.date).alias("year")).collect(),
            sdf.select(SF.date_trunc("year", sdf.date).alias("year")).collect(),
        )

    def test_head(self):
        # SPARK-41002: test `head` API in Python Client
        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
        self.assertIsNotNone(len(df.head()))
        self.assertIsNotNone(len(df.head(1)))
        self.assertIsNotNone(len(df.head(5)))

        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
        self.assertIsNone(df2.head())

    def test_first(self):
        # SPARK-41002: test `first` API in Python Client
        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
        self.assertIsNotNone(len(df.first()))

        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
        self.assertIsNone(df2.first())

    def test_take(self) -> None:
        # SPARK-41002: test `take` API in Python Client
        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
        self.assertEqual(5, len(df.take(5)))

        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
        self.assertEqual(0, len(df2.take(5)))

    def test_to_pandas(self):
        # SPARK-41005: Test to pandas
        query = """
            SELECT * FROM VALUES
            (false, 1, NULL),
            (false, NULL, float(2.0)),
            (NULL, 3, float(3.0))
            AS tab(a, b, c)
            """
        self.assert_eq(
            self.connect.sql(query).toPandas(),
            self.spark.sql(query).toPandas(),
        )

        query = """
            SELECT * FROM VALUES
            (1, 1, NULL),
            (2, NULL, float(2.0)),
            (3, 3, float(3.0))
            AS tab(a, b, c)
            """
        self.assert_eq(
            self.connect.sql(query).toPandas(),
            self.spark.sql(query).toPandas(),
        )

        query = """
            SELECT * FROM VALUES
            (double(1.0), 1, "1"),
            (NULL, NULL, NULL),
            (double(2.0), 3, "3")
            AS tab(a, b, c)
            """
        self.assert_eq(
            self.connect.sql(query).toPandas(),
            self.spark.sql(query).toPandas(),
        )

        query = """
            SELECT * FROM VALUES
            (float(1.0), double(1.0), 1, "1"),
            (float(2.0), double(2.0), 2, "2"),
            (float(3.0), double(3.0), 3, "3")
            AS tab(a, b, c, d)
            """
        self.assert_eq(
            self.connect.sql(query).toPandas(),
            self.spark.sql(query).toPandas(),
        )

    def test_collect_nested_type(self):
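        """Collect array, map and struct columns and compare Connect with classic Spark."""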
        query = """
            SELECT * FROM VALUES
            (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
            (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
            (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
            AS tab(a, b, c, d, e, f, g, h)
            """

        # +---+---+----+----+-----+----+------------+-------------------+
        # |  a|  b|   c|   d|    e|   f|           g|                  h|
        # +---+---+----+----+-----+----+------------+-------------------+
        # |  1|  4|   0|   8| true|true|[1, null, 3]|   {1 -> 2, 3 -> 4}|
        # |  2|  5|  -1|NULL|false|NULL|      [1, 3]|{1 -> null, 3 -> 4}|
        # |  3|  6|NULL|   0|false|NULL|      [null]|               NULL|
        # +---+---+----+----+-----+----+------------+-------------------+
        cdf = self.connect.sql(query)
        sdf = self.spark.sql(query)
        # test collect array
        # +--------------+-------------+------------+
        # |array(a, b, c)|  array(e, f)|           g|
        # +--------------+-------------+------------+
        # |     [1, 4, 0]| [true, true]|[1, null, 3]|
        # |    [2, 5, -1]|[false, null]|      [1, 3]|
        # |  [3, 6, null]|[false, null]|      [null]|
        # +--------------+-------------+------------+
        self.assertEqual(
            cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
            sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
        )

        # test collect nested array
        # +-----------------------------------+-------------------------+
        # |array(array(a), array(b), array(c))|array(array(e), array(f))|
        # +-----------------------------------+-------------------------+
        # |                    [[1], [4], [0]]|         [[true], [true]]|
        # |                   [[2], [5], [-1]]|        [[false], [null]]|
        # |                 [[3], [6], [null]]|        [[false], [null]]|
        # +-----------------------------------+-------------------------+
        self.assertEqual(
            cdf.select(
                CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
                CF.array(CF.array("e"), CF.array("f")),
            ).collect(),
            sdf.select(
                SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
                SF.array(SF.array("e"), SF.array("f")),
            ).collect(),
        )
        # test collect array of struct, map
        # +----------------+---------------------+
        # |array(struct(a))|             array(h)|
        # +----------------+---------------------+
        # |           [{1}]|   [{1 -> 2, 3 -> 4}]|
        # |           [{2}]|[{1 -> null, 3 -> 4}]|
        # |           [{3}]|               [null]|
        # +----------------+---------------------+
        self.assertEqual(
            cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
            sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
        )

        # test collect map
        # +-------------------+-------------------+
        # |                  h|    map(a, b, b, c)|
        # +-------------------+-------------------+
        # |   {1 -> 2, 3 -> 4}|   {1 -> 4, 4 -> 0}|
        # |{1 -> null, 3 -> 4}|  {2 -> 5, 5 -> -1}|
        # |               NULL|{3 -> 6, 6 -> null}|
        # +-------------------+-------------------+
        self.assertEqual(
            cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
            sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
        )

        # test collect map of struct, array
        # +-------------------+------------------------+
        # |          map(a, g)|    map(a, struct(b, g))|
        # +-------------------+------------------------+
        # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
        # |      {2 -> [1, 3]}|      {2 -> {5, [1, 3]}}|
        # |      {3 -> [null]}|      {3 -> {6, [null]}}|
        # +-------------------+------------------------+
        self.assertEqual(
            cdf.select(
                CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))
            ).collect(),
            sdf.select(
                SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))
            ).collect(),
        )
        # test collect struct
        # +------------------+--------------------------+
        # |struct(a, b, c, d)|           struct(e, f, g)|
        # +------------------+--------------------------+
        # |      {1, 4, 0, 8}|{true, true, [1, null, 3]}|
        # |  {2, 5, -1, null}|     {false, null, [1, 3]}|
        # |   {3, 6, null, 0}|     {false, null, [null]}|
        # +------------------+--------------------------+
        self.assertEqual(
            cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
            sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
        )

        # test collect nested struct
        # +------------------------------------------+--------------------------+----------------------------+  # noqa
        # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))|     struct(e, f, struct(g))|  # noqa
        # +------------------------------------------+--------------------------+----------------------------+  # noqa
        # |                        {1, {1, {0, {8}}}}|            {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}|  # noqa
        # |                    {2, {2, {-1, {null}}}}|        {2, 5, {-1, null}}|     {false, null, {[1, 3]}}|  # noqa
        # |                     {3, {3, {null, {0}}}}|         {3, 6, {null, 0}}|     {false, null, {[null]}}|  # noqa
        # +------------------------------------------+--------------------------+----------------------------+  # noqa
        self.assertEqual(
            cdf.select(
                CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
                CF.struct("a", "b", CF.struct("c", "d")),
                CF.struct("e", "f", CF.struct("g")),
            ).collect(),
            sdf.select(
                SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
                SF.struct("a", "b", SF.struct("c", "d")),
                SF.struct("e", "f", SF.struct("g")),
            ).collect(),
        )

        # test collect struct containing array, map
        # +--------------------------------------------+
        # |  struct(a, struct(a, struct(g, struct(h))))|
        # +--------------------------------------------+
        # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
        # |   {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
        # |                  {3, {3, {[null], {null}}}}|
        # +--------------------------------------------+
        self.assertEqual(
            cdf.select(
                CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
            ).collect(),
            sdf.select(
                SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
            ).collect(),
        )

    def test_collect_binary_type(self):
        """Test that df.collect() respects binary_as_bytes configuration for server-side data"""
        query = """
            SELECT * FROM VALUES
            (CAST('hello' AS BINARY)),
            (CAST('world' AS BINARY))
            AS tab(b)
            """
        for conf_value in ["true", "false"]:
            expected_type = bytes if conf_value == "true" else bytearray
            with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):
                connect_rows = self.connect.sql(query).collect()
                self.assertEqual(len(connect_rows), 2)
                for row in connect_rows:
                    self.assertIsInstance(row.b, expected_type)

                spark_rows = self.spark.sql(query).collect()
                self.assertEqual(len(spark_rows), 2)
                for row in spark_rows:
                    self.assertIsInstance(row.b, expected_type)

    def test_to_local_iterator_binary_type(self):
        """Test that df.toLocalIterator() respects binary_as_bytes configuration"""
        query = """
            SELECT * FROM VALUES
            (CAST('data1' AS BINARY)),
            (CAST('data2' AS BINARY))
            AS tab(b)
            """
        for conf_value in ["true", "false"]:
            expected_type = bytes if conf_value == "true" else bytearray
            with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):
                connect_count = 0
                for row in self.connect.sql(query).toLocalIterator():
                    self.assertIsInstance(row.b, expected_type)
                    connect_count += 1
                self.assertEqual(connect_count, 2)

                spark_count = 0
                for row in self.spark.sql(query).toLocalIterator():
                    self.assertIsInstance(row.b, expected_type)
                    spark_count += 1
                self.assertEqual(spark_count, 2)

    def test_foreach_partition_binary_type(self):
        """Test that df.foreachPartition() respects binary_as_bytes configuration.

        Since foreachPartition() runs on executors and cannot return data to the driver,
        we verify the types by asserting inside the function, which raises if a value
        does not have the expected type.
        """
        query = """
            SELECT * FROM VALUES
            (CAST('partition1' AS BINARY)),
            (CAST('partition2' AS BINARY))
            AS tab(b)
            """
        for conf_value in ["true", "false"]:
            expected_type = bytes if conf_value == "true" else bytearray
            expected_type_name = "bytes" if conf_value == "true" else "bytearray"
            with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):

                def assert_type(iterator):
                    count = 0
                    for row in iterator:
                        # This will raise an exception if the type is not as expected.
                        assert isinstance(
                            row.b, expected_type
                        ), f"Expected {expected_type_name}, got {type(row.b).__name__}"
                        count += 1
                    # Ensure we actually processed rows.
                    assert count > 0, "No rows were processed"

                self.connect.sql(query).foreachPartition(assert_type)
                self.spark.sql(query).foreachPartition(assert_type)


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_connect_collection import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)