#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
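"""Tests for the Arrow-based conversion helpers in pyspark.sql.conversion."""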
import datetime
import unittest
from zoneinfo import ZoneInfo
from pyspark.errors import PySparkValueError
from pyspark.sql.conversion import (
ArrowArrayToPandasConversion,
ArrowTableToRowsConversion,
LocalDataToArrowConversion,
ArrowArrayConversion,
ArrowBatchTransformer,
)
from pyspark.sql.types import (
ArrayType,
BinaryType,
GeographyType,
GeometryType,
IntegerType,
MapType,
NullType,
Row,
StringType,
StructField,
StructType,
UserDefinedType,
)
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT, PythonOnlyPoint, PythonOnlyUDT
from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message
class ScoreUDT(UserDefinedType):
    """A UDT whose underlying SQL type is a plain IntegerType."""
@classmethod
def sqlType(cls):
return IntegerType()
def serialize(self, obj):
return obj.score
def deserialize(self, datum):
return Score(datum)
class Score:
    """A Python-only value class converted via ScoreUDT."""
__UDT__ = ScoreUDT()
def __init__(self, score):
self.score = score
def __eq__(self, other):
return self.score == other.score
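# A minimal sketch of the UDT contract assumed below: serialize() maps an
# instance to its sqlType() representation and deserialize() inverts it, so
#     udt = ScoreUDT()
#     assert udt.deserialize(udt.serialize(Score(1))) == Score(1)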
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ArrowBatchTransformerTests(unittest.TestCase):
def test_flatten_struct_basic(self):
"""Test flattening a struct column into separate columns."""
import pyarrow as pa
struct_array = pa.StructArray.from_arrays(
[pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
names=["x", "y"],
)
batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
flattened = ArrowBatchTransformer.flatten_struct(batch)
self.assertEqual(flattened.num_columns, 2)
self.assertEqual(flattened.column(0).to_pylist(), [1, 2, 3])
self.assertEqual(flattened.column(1).to_pylist(), ["a", "b", "c"])
self.assertEqual(flattened.schema.names, ["x", "y"])
def test_flatten_struct_empty_batch(self):
"""Test flattening an empty batch."""
import pyarrow as pa
struct_type = pa.struct([("x", pa.int64()), ("y", pa.string())])
struct_array = pa.array([], type=struct_type)
batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
flattened = ArrowBatchTransformer.flatten_struct(batch)
self.assertEqual(flattened.num_rows, 0)
self.assertEqual(flattened.num_columns, 2)
def test_wrap_struct_basic(self):
"""Test wrapping columns into a struct."""
import pyarrow as pa
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
names=["x", "y"],
)
wrapped = ArrowBatchTransformer.wrap_struct(batch)
self.assertEqual(wrapped.num_columns, 1)
self.assertEqual(wrapped.schema.names, ["_0"])
struct_col = wrapped.column(0)
self.assertEqual(len(struct_col), 3)
self.assertEqual(struct_col.field(0).to_pylist(), [1, 2, 3])
self.assertEqual(struct_col.field(1).to_pylist(), ["a", "b", "c"])
def test_wrap_struct_empty_columns(self):
"""Test wrapping a batch with no columns."""
import pyarrow as pa
schema = pa.schema([])
batch = pa.RecordBatch.from_arrays([], schema=schema)
wrapped = ArrowBatchTransformer.wrap_struct(batch)
self.assertEqual(wrapped.num_columns, 1)
self.assertEqual(wrapped.num_rows, 0)
def test_wrap_struct_empty_batch(self):
"""Test wrapping an empty batch with schema."""
import pyarrow as pa
schema = pa.schema([("x", pa.int64()), ("y", pa.string())])
batch = pa.RecordBatch.from_arrays(
[pa.array([], type=pa.int64()), pa.array([], type=pa.string())],
schema=schema,
)
wrapped = ArrowBatchTransformer.wrap_struct(batch)
self.assertEqual(wrapped.num_rows, 0)
self.assertEqual(wrapped.num_columns, 1)
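    def test_wrap_then_flatten_round_trip(self):
        """A hedged sketch, not part of the original suite: based on the tests
        above, flatten_struct is assumed to exactly undo wrap_struct."""
        import pyarrow as pa

        batch = pa.RecordBatch.from_arrays(
            [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
            names=["x", "y"],
        )
        # wrap_struct packs all columns into a single struct column named "_0";
        # flatten_struct should then restore the original columns and names.
        round_tripped = ArrowBatchTransformer.flatten_struct(
            ArrowBatchTransformer.wrap_struct(batch)
        )
        self.assertEqual(round_tripped.schema.names, ["x", "y"])
        self.assertEqual(round_tripped.column(0).to_pylist(), [1, 2, 3])
        self.assertEqual(round_tripped.column(1).to_pylist(), ["a", "b", "c"])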
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ConversionTests(unittest.TestCase):
    def test_conversion(self):
        """Round-trip a row covering many data types through Arrow and back."""
data = [
            # Schema (or (schema, field_kwargs)), then test cases:
            # (value,) if the value round-trips unchanged, else (before, after).
(NullType(), (None,)),
(IntegerType(), (1,), (None,)),
((IntegerType(), {"nullable": False}), (1,)),
(StringType(), ("a",)),
(BinaryType(), (b"a",)),
(GeographyType("ANY"), (None,)),
(GeometryType("ANY"), (None,)),
(ArrayType(IntegerType()), ([1, None],)),
(ArrayType(IntegerType(), containsNull=False), ([1, 2],)),
(ArrayType(BinaryType()), ([b"a", b"b"],)),
(MapType(StringType(), IntegerType()), ({"a": 1, "b": None},)),
(
MapType(StringType(), IntegerType(), valueContainsNull=False),
({"a": 1},),
),
(MapType(StringType(), BinaryType()), ({"a": b"a"},)),
(
StructType(
[
StructField("i", IntegerType()),
StructField("i_n", IntegerType()),
StructField("ii", IntegerType(), nullable=False),
StructField("s", StringType()),
StructField("b", BinaryType()),
]
),
((1, None, 1, "a", b"a"), Row(i=1, i_n=None, ii=1, s="a", b=b"a")),
(
{"b": b"a", "s": "a", "ii": 1, "in": None, "i": 1},
Row(i=1, i_n=None, ii=1, s="a", b=b"a"),
),
),
(ExamplePointUDT(), (ExamplePoint(1.0, 1.0),)),
(ScoreUDT(), (Score(1),)),
]
schema = StructType()
input_row = []
expected = []
index = 0
for row_schema, *tests in data:
if isinstance(row_schema, tuple):
row_schema, kwargs = row_schema
else:
kwargs = {}
for test in tests:
if len(test) == 1:
before, after = test[0], test[0]
else:
before, after = test
schema.add(f"{row_schema.simpleString()}_{index}", row_schema, **kwargs)
input_row.append(before)
expected.append(after)
index += 1
tbl = LocalDataToArrowConversion.convert(
[tuple(input_row)], schema, use_large_var_types=False
)
actual = ArrowTableToRowsConversion.convert(tbl, schema)
for a, e in zip(
actual[0],
expected,
):
with self.subTest(expected=e):
self.assertEqual(a, e)
    def test_none_as_row(self):
        """A None input row converts to a Row of all-null fields."""
schema = StructType([StructField("x", IntegerType())])
tbl = LocalDataToArrowConversion.convert([None], schema, use_large_var_types=False)
actual = ArrowTableToRowsConversion.convert(tbl, schema)
self.assertEqual(actual[0], Row(x=None))
    def test_return_as_tuples(self):
        """With return_as_tuples=True, rows come back as plain tuples instead of Rows."""
schema = StructType([StructField("x", IntegerType())])
tbl = LocalDataToArrowConversion.convert([(1,)], schema, use_large_var_types=False)
actual = ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=True)
self.assertEqual(actual[0], (1,))
schema = StructType()
tbl = LocalDataToArrowConversion.convert([tuple()], schema, use_large_var_types=False)
actual = ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=True)
self.assertEqual(actual[0], tuple())
    def test_binary_as_bytes_conversion(self):
        """binary_as_bytes selects bytes vs bytearray for (nested) binary values."""
data = [
(
str(i).encode(), # simple binary
[str(j).encode() for j in range(3)], # array of binary
{str(j): str(j).encode() for j in range(2)}, # map with binary values
{"b": str(i).encode()}, # struct with binary
)
for i in range(2)
]
schema = (
StructType()
.add("b", BinaryType())
.add("arr_b", ArrayType(BinaryType()))
.add("map_b", MapType(StringType(), BinaryType()))
.add("struct_b", StructType().add("b", BinaryType()))
)
tbl = LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)
for binary_as_bytes, expected_type in [(True, bytes), (False, bytearray)]:
actual = ArrowTableToRowsConversion.convert(
tbl, schema, binary_as_bytes=binary_as_bytes
)
for row in actual:
# Simple binary field
self.assertIsInstance(row.b, expected_type)
# Array elements
for elem in row.arr_b:
self.assertIsInstance(elem, expected_type)
# Map values
for value in row.map_b.values():
self.assertIsInstance(value, expected_type)
# Struct field
self.assertIsInstance(row.struct_b.b, expected_type)
    def test_invalid_conversion(self):
        """Values violating the schema (e.g. nulls in non-nullable arrays) raise PySparkValueError."""
data = [
(NullType(), 1),
(ArrayType(IntegerType(), containsNull=False), [1, None]),
(ArrayType(ScoreUDT(), containsNull=False), [None]),
]
for row_schema, value in data:
schema = StructType([StructField("x", row_schema)])
with self.assertRaises(PySparkValueError):
LocalDataToArrowConversion.convert([(value,)], schema, use_large_var_types=False)
    def test_arrow_array_localize_tz(self):
        """localize_tz strips timezones from timestamp arrays, including nested ones."""
import pyarrow as pa
tz1 = ZoneInfo("Asia/Singapore")
tz2 = ZoneInfo("America/Los_Angeles")
tz3 = ZoneInfo("UTC")
ts0 = datetime.datetime(2026, 1, 5, 15, 0, 1)
ts1 = datetime.datetime(2026, 1, 5, 15, 0, 1, tzinfo=tz1)
ts2 = datetime.datetime(2026, 1, 5, 15, 0, 1, tzinfo=tz2)
ts3 = datetime.datetime(2026, 1, 5, 15, 0, 1, tzinfo=tz3)
        # Non-timestamp types: localize_tz must return the input array unchanged.
for arr in [
pa.array([1, 2]),
pa.array([["x", "y"]]),
pa.array([[[3.0, 4.0]]]),
pa.StructArray.from_arrays([pa.array([1, 2]), pa.array(["x", "y"])], names=["a", "b"]),
pa.array([{1: None, 2: "x"}], type=pa.map_(pa.int32(), pa.string())),
]:
output = ArrowArrayConversion.localize_tz(arr)
            self.assertTrue(output is arr, f"must not create a new array: {output.tolist()}")
        # Timestamp types: tz-aware values are converted to tz-naive wall-clock time.
for arr, expected in [
(pa.array([ts0, None]), pa.array([ts0, None])), # ts-ntz
(pa.array([ts1, None]), pa.array([ts0, None])), # ts-ltz
(pa.array([[ts2, None]]), pa.array([[ts0, None]])), # array<ts-ltz>
(pa.array([[[ts3, None]]]), pa.array([[[ts0, None]]])), # array<array<ts-ltz>>
(
pa.StructArray.from_arrays(
[pa.array([1, 2]), pa.array([ts0, None]), pa.array([ts1, None])],
names=["a", "b", "c"],
),
pa.StructArray.from_arrays(
[pa.array([1, 2]), pa.array([ts0, None]), pa.array([ts0, None])],
names=["a", "b", "c"],
),
), # struct<int, ts-ntz, ts-ltz>
(
pa.StructArray.from_arrays(
[pa.array([1, 2]), pa.array([[ts2], [None]])], names=["a", "b"]
),
pa.StructArray.from_arrays(
[pa.array([1, 2]), pa.array([[ts0], [None]])], names=["a", "b"]
),
), # struct<int, array<ts-ltz>>
(
pa.StructArray.from_arrays(
[
pa.array([ts2, None]),
pa.StructArray.from_arrays(
[pa.array(["a", "b"]), pa.array([[ts3], [None]])], names=["x", "y"]
),
],
names=["a", "b"],
),
pa.StructArray.from_arrays(
[
pa.array([ts0, None]),
pa.StructArray.from_arrays(
[pa.array(["a", "b"]), pa.array([[ts0], [None]])], names=["x", "y"]
),
],
names=["a", "b"],
),
), # struct<ts-ltz, struct<str, array<ts-ltz>>>
(
pa.array(
[{1: None, 2: ts1}],
type=pa.map_(pa.int32(), pa.timestamp("us", tz=tz1)),
),
pa.array(
[{1: None, 2: ts0}],
type=pa.map_(pa.int32(), pa.timestamp("us")),
),
), # map<int, ts-ltz>
(
pa.array(
[{1: [None], 2: [ts2, None]}],
type=pa.map_(pa.int32(), pa.list_(pa.timestamp("us", tz=tz2))),
),
pa.array(
[{1: [None], 2: [ts0, None]}],
type=pa.map_(pa.int32(), pa.list_(pa.timestamp("us"))),
),
), # map<int, array<ts-ltz>>
]:
output = ArrowArrayConversion.localize_tz(arr)
self.assertEqual(output, expected, f"{output.tolist()} != {expected.tolist()}")
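    def test_arrow_array_localize_tz_type(self):
        """A hedged sketch inferred from the cases above: localize_tz is assumed
        to strip the timezone from the Arrow type, keeping wall-clock values."""
        import pyarrow as pa

        tz = ZoneInfo("Asia/Singapore")
        aware = datetime.datetime(2026, 1, 5, 15, 0, 1, tzinfo=tz)
        # pyarrow infers timestamp("us", tz="Asia/Singapore") for this input.
        arr = pa.array([aware, None])
        output = ArrowArrayConversion.localize_tz(arr)
        # The output should be tz-naive with the same wall-clock value.
        self.assertEqual(output.type, pa.timestamp("us"))
        self.assertEqual(output.to_pylist(), [datetime.datetime(2026, 1, 5, 15, 0, 1), None])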
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ArrowArrayToPandasConversionTests(unittest.TestCase):
    def test_udt_convert_numpy(self):
        """Convert UDT-typed Arrow arrays to pandas Series of UDT instances."""
import pyarrow as pa
udt = ExamplePointUDT()
# basic conversion with nulls
arr = pa.array([[1.0, 2.0], None, [3.0, 4.0]], type=pa.list_(pa.float64()))
result = ArrowArrayToPandasConversion.convert_numpy(arr, udt, ser_name="my_point")
self.assertIsInstance(result.iloc[0], ExamplePoint)
self.assertEqual(result.iloc[0], ExamplePoint(1.0, 2.0))
self.assertIsNone(result.iloc[1])
self.assertEqual(result.iloc[2], ExamplePoint(3.0, 4.0))
self.assertEqual(result.name, "my_point")
# empty
result = ArrowArrayToPandasConversion.convert_numpy(
pa.array([], type=pa.list_(pa.float64())), udt
)
self.assertEqual(len(result), 0)
# PythonOnlyUDT
result = ArrowArrayToPandasConversion.convert_numpy(
pa.array([[5.0, 6.0]], type=pa.list_(pa.float64())), PythonOnlyUDT()
)
self.assertIsInstance(result.iloc[0], PythonOnlyPoint)
self.assertEqual(result.iloc[0], PythonOnlyPoint(5.0, 6.0))
    def test_udt_chunked_array(self):
        """convert_numpy also accepts a pyarrow ChunkedArray input."""
import pyarrow as pa
chunk1 = pa.array([[1.0, 2.0]], type=pa.list_(pa.float64()))
chunk2 = pa.array([[3.0, 4.0]], type=pa.list_(pa.float64()))
chunked = pa.chunked_array([chunk1, chunk2])
result = ArrowArrayToPandasConversion.convert_numpy(chunked, ExamplePointUDT())
self.assertEqual(result.iloc[0], ExamplePoint(1.0, 2.0))
self.assertEqual(result.iloc[1], ExamplePoint(3.0, 4.0))
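    def test_udt_convert_numpy_returns_series(self):
        """A hedged sketch, not part of the original suite: convert_numpy is
        assumed to return a pandas Series, given the .iloc/.name access above."""
        import pandas as pd
        import pyarrow as pa

        arr = pa.array([[1.0, 2.0], [3.0, 4.0]], type=pa.list_(pa.float64()))
        result = ArrowArrayToPandasConversion.convert_numpy(arr, ExamplePointUDT())
        self.assertIsInstance(result, pd.Series)
        self.assertEqual(len(result), 2)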
if __name__ == "__main__":
from pyspark.testing import main
main()