# blob: 141d9f37148e12b0590a2107f70c28b14b297193 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterator, Optional
import pyarrow as pa
import pyspark.sql
from pyspark.sql.types import StructType, StructField, BinaryType
from pyspark.sql.pandas.types import to_arrow_schema
def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
    """
    Lazily yield the contents of ``df`` as Arrow RecordBatches.

    Every incoming RecordBatch is serialized to Arrow IPC stream format via
    ``mapInArrow`` and shipped back as one binary cell. The resulting rows are then
    pulled with ``toLocalIterator`` and decoded back into RecordBatches one at a time.

    :param df: the Spark DataFrame whose data should be streamed.
    :return: an iterator over the decoded ``pa.RecordBatch`` objects.
    """
    blob_column = "arrow_ipc_bytes"
    # mapInArrow needs a declared output schema: one non-nullable binary column
    # holding a serialized RecordBatch (Arrow IPC format) per row.
    blob_field = StructField(blob_column, BinaryType(), nullable=False)
    binary_schema = StructType([blob_field])

    def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        """
        Turn each RecordBatch of a partition into a 1-row batch of Arrow IPC bytes.

        Spark hands over a partition as an iterator of RecordBatches (their size is
        governed by ``spark.sql.execution.arrow.maxRecordsPerBatch``). Each of them is
        written out as a self-contained Arrow IPC stream and yielded as a single blob.
        """
        for record_batch in batch_iter:
            # Serialize the batch into an in-memory Arrow IPC stream.
            out_stream = pa.BufferOutputStream()
            ipc_writer = pa.ipc.new_stream(out_stream, record_batch.schema)
            with ipc_writer:
                ipc_writer.write_batch(record_batch)
            # Read the sink only after the writer has been closed.
            payload = out_stream.getvalue()

            # Hand-build a one-element binary array on top of the payload buffer:
            # buffers are [validity bitmap, int32 offsets, data]. No validity bitmap
            # is supplied, and the offsets [0, len] make the whole payload one value.
            offsets = pa.array([0, len(payload)], type=pa.int32())
            blob_array = pa.Array.from_buffers(
                type=pa.binary(),
                length=1,
                buffers=[None, offsets.buffers()[1], payload],
            )
            # mapInArrow expects RecordBatches back, so wrap the single cell in one.
            yield pa.RecordBatch.from_arrays([blob_array], names=[blob_column])

    # Each row of this DataFrame is one serialized RecordBatch.
    serialized_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)

    # Fetch the rows one by one and decode every blob back into RecordBatches.
    for blob_row in serialized_df.toLocalIterator():
        with pa.ipc.open_stream(blob_row[blob_column]) as ipc_reader:
            for decoded_batch in ipc_reader:
                yield decoded_batch


class SparkArrowCStreamer:
    """
    Adapter exposing a Spark DataFrame through the ``__arrow_c_stream__`` protocol.

    The batches are produced lazily by ``_get_arrow_array_partition_stream``, which is
    intended to let consumers read the partitions one at a time instead of
    materializing all of them at once on the driver side.
    """

    def __init__(self, df: pyspark.sql.DataFrame):
        """
        Remember the DataFrame and derive the Arrow schema from its Spark schema.

        :param df: the Spark DataFrame to expose as an Arrow stream.
        """
        self._df = df
        self._schema = to_arrow_schema(df.schema)

    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
        """
        Return the Arrow C stream for the dataframe partitions.

        :param requested_schema: forwarded unchanged to the underlying
            ``pa.RecordBatchReader.__arrow_c_stream__`` call.
        :return: whatever the underlying reader's ``__arrow_c_stream__`` returns.
        """
        # Calling the generator function only creates the generator; no batch is
        # pulled until the reader is consumed.
        batch_stream = _get_arrow_array_partition_stream(self._df)
        stream_reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
            self._schema, batch_stream
        )
        return stream_reader.__arrow_c_stream__(requested_schema=requested_schema)