| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """ |
| Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. |
| """ |
| |
| from itertools import groupby |
| from typing import IO, TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Tuple |
| |
| from pyspark.errors import PySparkRuntimeError, PySparkValueError |
| from pyspark.serializers import ( |
| Serializer, |
| read_int, |
| write_int, |
| UTF8Deserializer, |
| CPickleSerializer, |
| ) |
| from pyspark.sql import Row |
| from pyspark.sql.conversion import ( |
| ArrowBatchTransformer, |
| PandasToArrowConversion, |
| ) |
| from pyspark.sql.pandas.types import ( |
| from_arrow_schema, |
| to_arrow_type, |
| ) |
| from pyspark.sql.types import ( |
| DataType, |
| StringType, |
| StructType, |
| BinaryType, |
| StructField, |
| LongType, |
| IntegerType, |
| ) |
| |
| if TYPE_CHECKING: |
| import pandas as pd |
| import pyarrow as pa |
| |
| |
| def _normalize_packed(packed): |
| """ |
| Normalize UDF output to a uniform tuple-of-tuples form. |
| |
| Iterator UDFs yield a single (series, spark_type) tuple directly, |
| while batched UDFs return a tuple of tuples ((s1, t1), (s2, t2), ...). |
| This function normalizes both forms to a tuple of tuples. |
| """ |
| if len(packed) == 2 and isinstance(packed[1], DataType): |
| return (packed,) |
| return tuple(packed) |
| |
| |
class SpecialLengths:
    """
    Negative sentinel values written as "lengths" in the worker/JVM stream
    protocol. For example, START_ARROW_STREAM is written before the first
    Arrow batch to signal that the Arrow stream is about to begin.
    """

    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3
    END_OF_STREAM = -4
    NULL = -5
    START_ARROW_STREAM = -6
| |
| |
class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
    in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # Pass every record batch through unchanged first.
        yield from self.serializer.load_stream(stream)

        # A count of -1 signals that an error occurred in the JVM; its UTF-8
        # message follows on the stream and is re-raised here.
        num = read_int(stream)
        if num == -1:
            error_msg = UTF8Deserializer().loads(stream)
            raise PySparkRuntimeError(
                errorClass="ERROR_OCCURRED_WHILE_CALLING",
                messageParameters={
                    "func_name": "ArrowCollectSerializer.load_stream",
                    "error_msg": error_msg,
                },
            )
        # Otherwise `num` ordering indices follow; yield them as one list.
        yield [read_int(stream) for _ in range(num)]

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer
| |
| |
class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a plain stream.

    Parameters
    ----------
    write_start_stream : bool
        If True, writes the START_ARROW_STREAM marker before the first
        output batch. Default False.
    """

    def __init__(self, write_start_stream: bool = False) -> None:
        super().__init__()
        self._write_start_stream: bool = write_start_stream

    def dump_stream(self, iterator: Iterable["pa.RecordBatch"], stream: IO[bytes]) -> None:
        """Write batches to the stream, optionally preceded by START_ARROW_STREAM."""
        import pyarrow as pa

        batch_iter = iter(iterator)
        if self._write_start_stream:
            batch_iter = self._write_stream_start(batch_iter, stream)

        writer = None
        try:
            for batch in batch_iter:
                # The writer needs a schema, so create it lazily from the
                # first batch seen.
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream: IO[bytes]) -> Iterator["pa.RecordBatch"]:
        """Read record batches from a plain Arrow IPC stream."""
        import pyarrow as pa

        yield from pa.ipc.open_stream(stream)

    def _write_stream_start(
        self, batch_iterator: Iterator["pa.RecordBatch"], stream: IO[bytes]
    ) -> Iterator["pa.RecordBatch"]:
        """Write START_ARROW_STREAM before the first batch, then pass batches through."""
        first = next(batch_iterator, None)
        if first is None:
            return

        # Signal the JVM only after the first batch was produced successfully,
        # so errors raised while creating it can still be reported before the
        # Arrow stream begins.
        write_int(SpecialLengths.START_ARROW_STREAM, stream)
        yield first
        yield from batch_iterator

    def __repr__(self) -> str:
        return "ArrowStreamSerializer(write_start_stream=%s)" % self._write_start_stream
| |
| |
class ArrowStreamGroupSerializer(ArrowStreamSerializer):
    """
    Extends :class:`ArrowStreamSerializer` with group-count protocol for loading
    grouped Arrow record batches (1 dataframe per group).
    """

    def load_stream(self, stream: IO[bytes]) -> Iterator[Iterator["pa.RecordBatch"]]:
        """Yield one iterator of record batches per group; a count of 0 ends the stream."""
        while True:
            dataframes_in_group = read_int(stream)
            if dataframes_in_group == 0:
                break
            if dataframes_in_group == 1:
                yield ArrowStreamSerializer.load_stream(self, stream)
            elif dataframes_in_group > 0:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
                )
| |
| |
class ArrowStreamCoGroupSerializer(ArrowStreamSerializer):
    """
    Extends :class:`ArrowStreamSerializer` with group-count protocol for loading
    cogrouped Arrow record batches (2 dataframes per group).
    """

    def load_stream(
        self, stream: IO[bytes]
    ) -> Iterator[Tuple[List["pa.RecordBatch"], List["pa.RecordBatch"]]]:
        """Yield (left_batches, right_batches) pairs; a count of 0 ends the stream."""
        while True:
            dataframes_in_group = read_int(stream)
            if dataframes_in_group == 0:
                break
            if dataframes_in_group == 2:
                # Each dataframe must be read eagerly so the stream position
                # stays aligned with the group-count protocol.
                left = list(ArrowStreamSerializer.load_stream(self, stream))
                right = list(ArrowStreamSerializer.load_stream(self, stream))
                yield left, right
            elif dataframes_in_group > 0:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
                )
| |
| |
class ArrowStreamUDFSerializer(ArrowStreamSerializer):
    """
    Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch
    for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`.
    """

    def load_stream(self, stream):
        """
        Load batches and flatten each struct column into a plain record batch,
        yielding every batch wrapped in a one-element list.
        """
        for batch in super().load_stream(stream):
            yield [ArrowBatchTransformer.flatten_struct(batch)]

    def dump_stream(self, iterator, stream):
        """
        Wrap each batch back into a struct column and emit START_ARROW_STREAM
        before the Arrow stream, as Pandas UDFs require.
        """
        wrapped = (ArrowBatchTransformer.wrap_struct(pair[0]) for pair in iterator)
        return super().dump_stream(self._write_stream_start(wrapped, stream), stream)
| |
| |
class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
    """
    Same as :class:`ArrowStreamUDFSerializer` but it does not flatten when loading batches.
    """

    def load_stream(self, stream):
        """Load raw record batches, bypassing the struct flattening of the parent."""
        yield from ArrowStreamSerializer.load_stream(self, stream)
| |
| |
class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
    """
    Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
    """

    def __init__(self, *, table_arg_offsets=None):
        super().__init__()
        # Column offsets of TABLE arguments; these struct columns are
        # flattened into RecordBatches when loading.
        self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []

    def load_stream(self, stream):
        """
        Yield one list per batch: columns at table_arg_offsets are flattened
        from struct into RecordBatches, all other columns stay as Arrays.
        """
        for batch in super().load_stream(stream):
            columns = []
            for i in range(batch.num_columns):
                if i in self.table_arg_offsets:
                    columns.append(ArrowBatchTransformer.flatten_struct(batch, column_index=i))
                else:
                    columns.append(batch.column(i))
            yield columns

    def dump_stream(self, iterator, stream):
        """
        Coerce each output batch to its declared Arrow return schema before
        serializing. The UDTF yields (pa.RecordBatch, arrow_return_type) pairs.
        """
        import pyarrow as pa

        def apply_type_coercion():
            for batch, arrow_return_type in iterator:
                assert isinstance(arrow_return_type, pa.StructType), (
                    f"Expected pa.StructType, got {type(arrow_return_type)}"
                )
                yield (
                    ArrowBatchTransformer.enforce_schema(
                        batch, pa.schema(arrow_return_type), safecheck=True
                    ),
                    arrow_return_type,
                )

        return super().dump_stream(apply_type_coercion(), stream)
| |
| |
class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
    """
    Serializer for grouped Arrow UDFs.

    Deserializes:
        ``Iterator[Iterator[pa.RecordBatch]]`` - one inner iterator per group.
        Each batch contains a single struct column.

    Serializes:
        ``Iterator[Tuple[Iterator[pa.RecordBatch], pa.DataType]]``
        Each tuple contains iterator of flattened batches and their Arrow type.

    Used by:
    - SQL_GROUPED_MAP_ARROW_UDF
    - SQL_GROUPED_MAP_ARROW_ITER_UDF

    Parameters
    ----------
    assign_cols_by_name : bool
        If True, reorder serialized columns by schema name.
    """

    def __init__(self, *, assign_cols_by_name):
        super().__init__()
        self._assign_cols_by_name = assign_cols_by_name

    def load_stream(self, stream):
        """
        Yield one iterator of record batches per group.
        """
        for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
            yield batches
            # Drain any batches the consumer left unread so the stream is
            # positioned at the start of the next group.
            for _ in batches:
                pass

    def dump_stream(self, iterator, stream):
        import pyarrow as pa

        def flatten_groups():
            # Strip the inner iterator induced by ArrowStreamUDFSerializer.load_stream:
            # [([pa.RecordBatch], arrow_type)] -> [(pa.RecordBatch, arrow_type)].
            # The tuples are constructed in wrap_grouped_map_arrow_udf.
            for batches, arrow_type in iterator:
                for batch in batches:
                    if self._assign_cols_by_name:
                        # Reorder columns to match the schema's field order by name.
                        names = [field.name for field in arrow_type]
                        batch = pa.RecordBatch.from_arrays(
                            [batch.column(name) for name in names], names=names
                        )
                    yield batch, arrow_type

        super().dump_stream(flatten_groups(), stream)
| |
| |
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes pandas.Series as Arrow data with Arrow streaming format.

    Parameters
    ----------
    timezone : str
        A timezone to respect when handling timestamp values
    safecheck : bool
        If True, conversion from Arrow to Pandas checks for overflow/truncation
    int_to_decimal_coercion_enabled : bool
        If True, applies additional coercions in Python before converting to Arrow.
        This has performance penalties.
    prefers_large_types : bool
        If True, prefer large Arrow types (e.g., large_string instead of string).
    struct_in_pandas : str, optional
        How to represent struct in pandas ("dict", "row", etc.). Default is "dict".
    ndarray_as_list : bool, optional
        Whether to convert ndarray as list. Default is False.
    prefer_int_ext_dtype : bool, optional
        Whether to convert integers to Pandas ExtensionDType. Default is False.
    df_for_struct : bool, optional
        If True, convert struct columns to DataFrame instead of Series. Default is False.
    input_type : StructType, optional
        Expected input schema, passed to the Arrow-to-pandas conversion.
        Default is None.
    arrow_cast : bool, optional
        Stored for subclasses that forward it to the pandas-to-Arrow
        conversion. Default is False.
    """

    def __init__(
        self,
        *,
        timezone,
        safecheck,
        int_to_decimal_coercion_enabled: bool = False,
        prefers_large_types: bool = False,
        struct_in_pandas: str = "dict",
        ndarray_as_list: bool = False,
        prefer_int_ext_dtype: bool = False,
        df_for_struct: bool = False,
        input_type: Optional["StructType"] = None,
        arrow_cast: bool = False,
    ):
        super().__init__()
        if input_type is not None:
            assert isinstance(input_type, StructType)
        self._timezone = timezone
        self._safecheck = safecheck
        self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
        self._prefers_large_types = prefers_large_types
        self._struct_in_pandas = struct_in_pandas
        self._ndarray_as_list = ndarray_as_list
        self._prefer_int_ext_dtype = prefer_int_ext_dtype
        self._df_for_struct = df_for_struct
        self._input_type = input_type
        self._arrow_cast = arrow_cast

    def dump_stream(self, iterator, stream):
        """
        Convert (pandas.Series, spark_type) pairs to Arrow record batches and
        serialize them. Each element of `iterator` is either a tuple of pairs
        (batched UDFs) or a single bare pair (iterator UDFs); both forms are
        normalized via `_normalize_packed`.
        """

        def to_arrow_batch(pairs):
            columns = [series for series, _ in pairs]
            # Columns are given placeholder names "_0", "_1", ...
            fields = [StructField(f"_{i}", t) for i, (_, t) in enumerate(pairs)]
            return PandasToArrowConversion.convert(
                columns,
                StructType(fields),
                timezone=self._timezone,
                safecheck=self._safecheck,
                prefers_large_types=self._prefers_large_types,
                int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
            )

        super().dump_stream(
            (to_arrow_batch(_normalize_packed(packed)) for packed in iterator), stream
        )

    def load_stream(self, stream):
        """
        Deserialize Arrow record batches and convert each into a list of pandas.Series.
        """
        for batch in super().load_stream(stream):
            yield ArrowBatchTransformer.to_pandas(
                batch,
                timezone=self._timezone,
                schema=self._input_type,
                struct_in_pandas=self._struct_in_pandas,
                ndarray_as_list=self._ndarray_as_list,
                prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                df_for_struct=self._df_for_struct,
            )

    def __repr__(self):
        return "ArrowStreamPandasSerializer"
| |
| |
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by Python worker to evaluate Pandas UDFs
    """

    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        df_for_struct: bool = False,
        struct_in_pandas: str = "dict",
        ndarray_as_list: bool = False,
        prefer_int_ext_dtype: bool = False,
        arrow_cast: bool = False,
        input_type: Optional[StructType] = None,
        int_to_decimal_coercion_enabled: bool = False,
        prefers_large_types: bool = False,
        ignore_unexpected_complex_type_values: bool = False,
        is_legacy: bool = False,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
            prefers_large_types=prefers_large_types,
            struct_in_pandas=struct_in_pandas,
            ndarray_as_list=ndarray_as_list,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            df_for_struct=df_for_struct,
            input_type=input_type,
            arrow_cast=arrow_cast,
        )
        self._assign_cols_by_name = assign_cols_by_name
        self._ignore_unexpected_complex_type_values = ignore_unexpected_complex_type_values
        self._is_legacy = is_legacy

    def dump_stream(self, iterator, stream):
        """
        Serialize (series, spark_type) pairs as Arrow batches. The
        START_ARROW_STREAM marker is written only after the first batch has
        been created, so an error during batch creation can still be reported
        to the JVM before the Arrow stream starts.

        Each element in iterator is:
        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), (s2, t2), ...)
        - For iterator UDFs: single (series, spark_type) tuple directly
        """
        import pandas as pd

        def to_arrow_batch(pairs):
            if self._struct_in_pandas == "dict":
                # With struct_in_pandas="dict", a struct-typed result must be
                # returned as a pandas.DataFrame, not a Series.
                for result, spark_type in pairs:
                    if isinstance(spark_type, StructType) and not isinstance(
                        result, pd.DataFrame
                    ):
                        raise PySparkValueError(
                            "Invalid return type. Please make sure that the UDF returns a "
                            "pandas.DataFrame when the specified return type is StructType."
                        )

            columns = [result for result, _ in pairs]
            # Columns are given placeholder names "_0", "_1", ...
            fields = [StructField(f"_{i}", t) for i, (_, t) in enumerate(pairs)]
            return PandasToArrowConversion.convert(
                columns,
                StructType(fields),
                timezone=self._timezone,
                safecheck=self._safecheck,
                arrow_cast=self._arrow_cast,
                prefers_large_types=self._prefers_large_types,
                assign_cols_by_name=self._assign_cols_by_name,
                int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
                ignore_unexpected_complex_type_values=self._ignore_unexpected_complex_type_values,
                is_legacy=self._is_legacy,
            )

        batch_iter = self._write_stream_start(
            (to_arrow_batch(_normalize_packed(packed)) for packed in iterator),
            stream,
        )
        return ArrowStreamSerializer.dump_stream(self, batch_iter, stream)

    def __repr__(self):
        return "ArrowStreamPandasUDFSerializer"
| |
| |
class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
    """
    Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
    """

    def __init__(
        self,
        *,
        timezone,
        safecheck,
        input_type,
        prefer_int_ext_dtype,
        int_to_decimal_coercion_enabled,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            input_type=input_type,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            # Enable additional coercions for UDTF serialization.
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
            # The output pandas DataFrame's columns are unnamed.
            assign_cols_by_name=False,
            # Keep struct type inputs as Series instead of expanding them into
            # a pandas DataFrame.
            df_for_struct=False,
            # Convert struct type inputs into Rows rather than dictionaries.
            # For example, named_struct('name', 'Alice', 'age', 1) becomes
            # Row(name="Alice", age=1) with "row", but {"name": "Alice",
            # "age": 1} with the default "dict".
            struct_in_pandas="row",
            # Arrow turns array type inputs into numpy.ndarrays; convert them
            # to Python lists so regular and arrow-optimized UDTFs behave
            # consistently.
            ndarray_as_list=True,
            # Enable explicit casting for mismatched return types of Arrow
            # Python UDTFs.
            arrow_cast=True,
            # UDTF-specific: ignore unexpected complex type values in the
            # converter.
            ignore_unexpected_complex_type_values=True,
            # Legacy UDTF pandas conversion: broader Arrow exception handling
            # allows more implicit type coercions.
            is_legacy=True,
        )

    def __repr__(self):
        return "ArrowStreamPandasUDTFSerializer"
| |
| |
| # Serializer for SQL_GROUPED_AGG_PANDAS_UDF, SQL_WINDOW_AGG_PANDAS_UDF, |
| # and SQL_GROUPED_AGG_PANDAS_ITER_UDF |
class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        prefer_int_ext_dtype,
        int_to_decimal_coercion_enabled,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=assign_cols_by_name,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
            df_for_struct=False,
            struct_in_pandas="dict",
            ndarray_as_list=False,
            arrow_cast=True,
            input_type=None,
        )

    def load_stream(self, stream):
        """
        For each group, yield an iterator that produces one tuple of
        pandas.Series per Arrow batch, so the UDF can consume batches lazily
        instead of materializing the whole group up front.
        """

        def to_series_tuple(batch):
            # Convert a single Arrow batch lazily, as it is pulled from the stream.
            return tuple(
                ArrowBatchTransformer.to_pandas(
                    batch,
                    timezone=self._timezone,
                    schema=self._input_type,
                    struct_in_pandas=self._struct_in_pandas,
                    ndarray_as_list=self._ndarray_as_list,
                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                    df_for_struct=self._df_for_struct,
                )
            )

        for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
            series_iter = map(to_series_tuple, batches)
            yield series_iter
            # Drain whatever the consumer left unread so the stream is
            # positioned at the start of the next group.
            for _ in series_iter:
                pass

    def __repr__(self):
        return "ArrowStreamAggPandasUDFSerializer"
| |
| |
| # Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF |
class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        prefer_int_ext_dtype,
        int_to_decimal_coercion_enabled,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=assign_cols_by_name,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
            df_for_struct=False,
            struct_in_pandas="dict",
            ndarray_as_list=False,
            arrow_cast=True,
            input_type=None,
        )

    def load_stream(self, stream):
        """
        Deserialize grouped Arrow record batches, yielding one raw
        Iterator[pa.RecordBatch] per group.
        """
        for batches in ArrowStreamGroupSerializer.load_stream(self, stream):
            yield batches
            # Drain unread batches so the stream is positioned at the start of
            # the next group.
            for _ in batches:
                pass

    def dump_stream(self, iterator, stream):
        """
        Flatten the per-group nesting and delegate serialization to the parent.
        """
        # Iterator[Iterator[[(df, spark_type)]]] -> Iterator[[(df, spark_type)]]
        super().dump_stream((packed for group in iterator for packed in group), stream)

    def __repr__(self):
        return "GroupPandasUDFSerializer"
| |
| |
class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
    """
    Serializes pyarrow.RecordBatch data with Arrow streaming format.

    Loads Arrow record batches as `[([pa.RecordBatch], [pa.RecordBatch])]` (one tuple per group)
    and serializes `[([pa.RecordBatch], arrow_type)]`.

    Parameters
    ----------
    assign_cols_by_name : bool
        If True, then DataFrames will get columns by name
    """

    def load_stream(self, stream):
        """
        Yield one (left_batches, right_batches) pair of record-batch lists per cogroup.
        """
        return ArrowStreamCoGroupSerializer.load_stream(self, stream)
| |
| |
class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
    def load_stream(self, stream):
        """
        Deserialize cogrouped Arrow record batches and yield, per group, a pair
        of pandas.Series lists (left side, right side).
        """
        import pyarrow as pa

        def side_to_pandas(batches):
            # Combine the side's batches into one table, then convert with the
            # schema derived from that table.
            table = pa.Table.from_batches(batches)
            return ArrowBatchTransformer.to_pandas(
                table,
                timezone=self._timezone,
                schema=from_arrow_schema(table.schema),
                struct_in_pandas=self._struct_in_pandas,
                ndarray_as_list=self._ndarray_as_list,
                prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                df_for_struct=self._df_for_struct,
            )

        for left_batches, right_batches in ArrowStreamCoGroupSerializer.load_stream(self, stream):
            yield side_to_pandas(left_batches), side_to_pandas(right_batches)
| |
| |
| class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): |
| """ |
| Serializer used by Python worker to evaluate UDF for applyInPandasWithState. |
| |
| Parameters |
| ---------- |
| timezone : str |
| A timezone to respect when handling timestamp values |
| safecheck : bool |
| If True, conversion from Arrow to Pandas checks for overflow/truncation |
| assign_cols_by_name : bool |
| If True, then Pandas DataFrames will get columns by name |
| state_object_schema : StructType |
| The type of state object represented as Spark SQL type |
| arrow_max_records_per_batch : int |
| Limit of the number of records that can be written to a single ArrowRecordBatch in memory. |
| """ |
| |
    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        prefer_int_ext_dtype,
        state_object_schema,
        arrow_max_records_per_batch,
        prefers_large_var_types,
        int_to_decimal_coercion_enabled,
    ):
        """Initialize the serializer; see the class docstring for parameter meanings."""
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=assign_cols_by_name,
            df_for_struct=False,
            struct_in_pandas="dict",
            ndarray_as_list=False,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            arrow_cast=True,
            input_type=None,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
            prefers_large_types=prefers_large_var_types,
        )
        # Used to (de)serialize the pickled user state object.
        self.pickleSer = CPickleSerializer()
        self.utf8_deserializer = UTF8Deserializer()
        self.state_object_schema = state_object_schema

        # Spark SQL schema describing per-result counts (data rows, state rows).
        self.result_count_df_type = StructType(
            [
                StructField("dataCount", IntegerType()),
                StructField("stateCount", IntegerType()),
            ]
        )

        self.result_count_pdf_arrow_type = to_arrow_type(
            self.result_count_df_type, timezone="UTC", prefers_large_types=prefers_large_var_types
        )

        # Spark SQL schema of the serialized state-information column; mirrors
        # the fields read back in load_stream's `construct_state`.
        self.result_state_df_type = StructType(
            [
                StructField("properties", StringType()),
                StructField("keyRowAsUnsafe", BinaryType()),
                StructField("object", BinaryType()),
                StructField("oldTimeoutTimestamp", LongType()),
            ]
        )

        self.result_state_pdf_arrow_type = to_arrow_type(
            self.result_state_df_type, timezone="UTC", prefers_large_types=prefers_large_var_types
        )
        # A non-positive configured limit means "unbounded": cap at INT32_MAX.
        self.arrow_max_records_per_batch = (
            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
        )
| |
    def load_stream(self, stream):
        """
        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
        (data chunk, state), and convert the data into a list of pandas.Series.

        Please refer the doc of inner function `gen_data_and_state` for more details how
        this function works in overall.

        In addition, this function further groups the return of `gen_data_and_state` by the state
        instance (same semantic as grouping by grouping key) and produces an iterator of data
        chunks for each group, so that the caller can lazily materialize the data chunk.
        """

        import pyarrow as pa
        import json
        from itertools import groupby
        from pyspark.sql.streaming.state import GroupState

        def construct_state(state_info_col):
            """
            Construct state instance from the value of state information column.
            """

            state_info_col_properties = state_info_col["properties"]
            state_info_col_key_row = state_info_col["keyRowAsUnsafe"]
            state_info_col_object = state_info_col["object"]

            state_properties = json.loads(state_info_col_properties)
            # The pickled state object bytes are only present when a state
            # value exists for the group.
            if state_info_col_object:
                state_object = self.pickleSer.loads(state_info_col_object)
            else:
                state_object = None
            state_properties["optionalValue"] = state_object

            return GroupState(
                keyAsUnsafe=state_info_col_key_row,
                valueSchema=self.state_object_schema,
                **state_properties,
            )

        def gen_data_and_state(batches):
            """
            Deserialize ArrowRecordBatches and return a generator of
            `(a list of pandas.Series, state)`.

            The logic on deserialization is following:

            1. Read the entire data part from Arrow RecordBatch.
            2. Read the entire state information part from Arrow RecordBatch.
            3. Loop through each state information:
               3.A. Extract the data out from entire data via the information of data range.
               3.B. Construct a new state instance if the state information is the first occurrence
                    for the current grouping key.
               3.C. Leverage the existing state instance if it is already available for the current
                    grouping key. (Meaning it's not the first occurrence.)
               3.D. Remove the cache of state instance if the state information denotes the data is
                    the last chunk for current grouping key.

            This deserialization logic assumes that Arrow RecordBatches contain the data with the
            ordering that data chunks for same grouping key will appear sequentially.

            This function must avoid materializing multiple Arrow RecordBatches into memory at the
            same time. And data chunks from the same grouping key should appear sequentially, to
            further group them based on state instance (same state instance will be produced for
            same grouping key).
            """

            state_for_current_group = None

            for batch in batches:
                # Batch layout: data columns first, then one trailing
                # state-information column.
                batch_schema = batch.schema
                data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)])
                state_schema = pa.schema(
                    [
                        batch_schema[-1],
                    ]
                )

                batch_columns = batch.columns
                data_columns = batch_columns[0:-1]
                state_column = batch_columns[-1]

                data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema)
                state_batch = pa.RecordBatch.from_arrays(
                    [
                        state_column,
                    ],
                    schema=state_schema,
                )

                # Convert only the (single-column) state batch eagerly; the
                # data part is sliced lazily per group below.
                state_pandas = ArrowBatchTransformer.to_pandas(
                    state_batch,
                    timezone=self._timezone,
                    schema=None,
                    struct_in_pandas=self._struct_in_pandas,
                    ndarray_as_list=self._ndarray_as_list,
                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                    df_for_struct=self._df_for_struct,
                )[0]

                for state_idx in range(0, len(state_pandas)):
                    state_info_col = state_pandas.iloc[state_idx]

                    if not state_info_col:
                        # no more data with grouping key + state
                        break

                    # Row range of this group's data chunk within the batch.
                    data_start_offset = state_info_col["startOffset"]
                    num_data_rows = state_info_col["numRows"]
                    is_last_chunk = state_info_col["isLastChunk"]

                    if state_for_current_group:
                        # use the state, we already have state for same group and there should be
                        # some data in same group being processed earlier
                        state = state_for_current_group
                    else:
                        # there is no state being stored for same group, construct one
                        state = construct_state(state_info_col)

                    if is_last_chunk:
                        # discard the state being cached for same group
                        state_for_current_group = None
                    elif not state_for_current_group:
                        # there's no cached state but expected to have additional data in same group
                        # cache the current state
                        state_for_current_group = state

                    # Zero-copy slice of the batch's data portion for this group.
                    data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows)
                    data_pandas = ArrowBatchTransformer.to_pandas(
                        data_batch_for_group,
                        timezone=self._timezone,
                        schema=None,
                        struct_in_pandas=self._struct_in_pandas,
                        ndarray_as_list=self._ndarray_as_list,
                        prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                        df_for_struct=self._df_for_struct,
                    )

                    # state info
                    yield (
                        data_pandas,
                        state,
                    )

        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)

        data_state_generator = gen_data_and_state(_batches)

        # state will be same object for same grouping key
        # (groupby relies on chunks for the same key arriving consecutively)
        for _state, _data in groupby(data_state_generator, key=lambda x: x[1]):
            yield (
                _data,
                _state,
            )
| |
    def dump_stream(self, iterator, stream):
        """
        Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow
        RecordBatches, and write batches to stream.
        """

        import pandas as pd

        def construct_state_pdf(state):
            """
            Construct a single-row pandas DataFrame from the state instance.

            The row carries everything needed to restore the state on the JVM side:
            the JSON-encoded state properties, the grouping key as an UnsafeRow
            binary, the pickled state value (None when the state does not exist),
            and the previous timeout timestamp.
            """

            state_properties = state.json().encode("utf-8")
            state_key_row_as_binary = state._keyAsUnsafe
            if state.exists:
                state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))
            else:
                # No value is set for this key; serialize a null object.
                state_object = None
            state_old_timeout_timestamp = state.oldTimeoutTimestamp

            # Each field becomes a single-element column, producing a one-row DataFrame.
            state_dict = {
                "properties": [
                    state_properties,
                ],
                "keyRowAsUnsafe": [
                    state_key_row_as_binary,
                ],
                "object": [
                    state_object,
                ],
                "oldTimeoutTimestamp": [
                    state_old_timeout_timestamp,
                ],
            }

            return pd.DataFrame.from_dict(state_dict)

        def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt):
            """
            Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each
            one matches to the single struct field for Arrow schema. We also need an extra one to
            indicate array length for data and state, so the return value of Arrow RecordBatch will
            have schema with three fields, in `count`, `data`, `state` order.
            (Readers are expected to access the field via position rather than the name. We do
            not guarantee the name of the field.)

            Note that Arrow RecordBatch requires all columns to have all same number of rows,
            hence this function inserts empty data for count/state/data with less elements to
            compensate.
            """

            # At least one row is always produced so the count column is readable even
            # when both data and state are empty.
            max_data_cnt = max(1, max(pdf_data_cnt, state_data_cnt))

            # We only use the first row in the count column, and fill other rows to be the same
            # value, hoping it is more friendly for compression, in case it is needed.
            count_dict = {
                "dataCount": [pdf_data_cnt] * max_data_cnt,
                "stateCount": [state_data_cnt] * max_data_cnt,
            }
            count_pdf = pd.DataFrame.from_dict(count_dict)

            # Pad the shorter of data/state with empty (all-null) rows so that all three
            # struct columns end up with exactly max_data_cnt rows.
            empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt
            empty_row_cnt_in_state = max_data_cnt - state_data_cnt

            empty_rows_pdf = pd.DataFrame(
                dict.fromkeys(pdf_schema.names),
                index=[x for x in range(0, empty_row_cnt_in_data)],
            )
            empty_rows_state = pd.DataFrame(
                columns=["properties", "keyRowAsUnsafe", "object", "oldTimeoutTimestamp"],
                index=[x for x in range(0, empty_row_cnt_in_state)],
            )

            pdfs.append(empty_rows_pdf)
            state_pdfs.append(empty_rows_state)

            merged_pdf = pd.concat(pdfs, ignore_index=True)
            merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)

            # Create batch from list of DataFrames, each wrapped as a StructArray.
            # Schema fields map to: _0=count, _1=output data, _2=state data
            # (types defined in __init__: result_count_df_type, pdf_schema, result_state_df_type)
            data = [count_pdf, merged_pdf, merged_state_pdf]
            schema = StructType(
                [
                    StructField("_0", self.result_count_df_type),
                    StructField("_1", pdf_schema),
                    StructField("_2", self.result_state_df_type),
                ]
            )
            return PandasToArrowConversion.convert(
                data,
                schema,
                timezone=self._timezone,
                safecheck=self._safecheck,
                arrow_cast=self._arrow_cast,
                assign_cols_by_name=self._assign_cols_by_name,
                int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
            )

        def serialize_batches():
            """
            Read through an iterator of (iterator of pandas DataFrame, state), and serialize them
            to Arrow RecordBatches.

            This function does batching on constructing the Arrow RecordBatch; a batch will be
            serialized to the Arrow RecordBatch when the total number of records exceeds the
            configured threshold.
            """
            # a set of variables for the state of current batch which will be converted to Arrow
            # RecordBatch.
            pdfs = []
            state_pdfs = []
            pdf_data_cnt = 0
            state_data_cnt = 0

            return_schema = None

            for data in iterator:
                # data represents the result of each call of user function
                packaged_result = data[0]

                # There are two results from the call of user function:
                # 1) iterator of pandas DataFrame (output)
                # 2) updated state instance
                pdf_iter = packaged_result[0][0]
                state = packaged_result[0][1]

                # This is static and won't change across batches.
                return_schema = packaged_result[1]

                for pdf in pdf_iter:
                    # We ignore empty pandas DataFrame.
                    if len(pdf) > 0:
                        pdf_data_cnt += len(pdf)
                        pdfs.append(pdf)

                        # If the total number of records in current batch exceeds the configured
                        # threshold, time to construct the Arrow RecordBatch from the batch.
                        if pdf_data_cnt > self.arrow_max_records_per_batch:
                            batch = construct_record_batch(
                                pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
                            )

                            # Reset the variables to start with new batch for further data.
                            pdfs = []
                            state_pdfs = []
                            pdf_data_cnt = 0
                            state_data_cnt = 0

                            yield batch

                # This has to be performed 'after' evaluating all elements in iterator, so that
                # the user function has been completed and the state is guaranteed to be updated.
                state_pdf = construct_state_pdf(state)

                state_pdfs.append(state_pdf)
                state_data_cnt += 1

            # processed all output, but current batch may not be flushed yet.
            if pdf_data_cnt > 0 or state_data_cnt > 0:
                batch = construct_record_batch(
                    pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
                )

                yield batch

        batches = self._write_stream_start(serialize_batches(), stream)
        return ArrowStreamSerializer.dump_stream(self, batches, stream)
| |
| |
class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
    """
    Serializer used by Python worker to evaluate UDF for
    :meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`.

    Parameters
    ----------
    timezone : str
        A timezone to respect when handling timestamp values
    safecheck : bool
        If True, conversion from Arrow to Pandas checks for overflow/truncation
    assign_cols_by_name : bool
        If True, then Pandas DataFrames will get columns by name
    arrow_max_records_per_batch : int
        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
    arrow_max_bytes_per_batch : int
        Soft limit on the estimated number of bytes per yielded chunk, enforced with a
        running average of the observed Arrow row size.
    prefer_int_ext_dtype : bool
        Forwarded to :class:`ArrowStreamPandasUDFSerializer`.
    int_to_decimal_coercion_enabled : bool
        Forwarded to :class:`ArrowStreamPandasUDFSerializer`.
    """

    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        prefer_int_ext_dtype,
        arrow_max_records_per_batch,
        arrow_max_bytes_per_batch,
        int_to_decimal_coercion_enabled,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=assign_cols_by_name,
            df_for_struct=False,
            struct_in_pandas="dict",
            ndarray_as_list=False,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            arrow_cast=True,
            input_type=None,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
        )
        # A non-positive limit means "unlimited"; cap at max 32-bit signed int.
        self.arrow_max_records_per_batch = (
            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
        )
        self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
        # Positions of the grouping key columns within a row; assigned by the worker
        # before load_stream is consumed.
        self.key_offsets = None
        # Running statistics used for adaptive (byte-size-based) chunking.
        self.average_arrow_row_size = 0
        self.total_bytes = 0
        self.total_rows = 0

    def _update_batch_size_stats(self, batch):
        """
        Update batch size statistics for adaptive batching.
        """
        # Short circuit batch size calculation if the batch size is
        # unlimited as computing batch size is computationally expensive.
        if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
            # Sum the sizes of all non-null Arrow buffers backing the batch columns.
            batch_bytes = sum(
                buf.size for col in batch.columns for buf in col.buffers() if buf is not None
            )
            self.total_bytes += batch_bytes
            self.total_rows += batch.num_rows
            self.average_arrow_row_size = self.total_bytes / self.total_rows

    def load_stream(self, stream):
        """
        Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and
        convert the data into Rows.

        Please refer the doc of inner function `generate_data_batches` for more details how
        this function works in overall.
        """
        import pandas as pd
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        def generate_data_batches(batches):
            """
            Deserialize ArrowRecordBatches and return a generator of Rows.

            The deserialization logic assumes that Arrow RecordBatches contain the data with the
            ordering that data chunks for same grouping key will appear sequentially.

            This function must avoid materializing multiple Arrow RecordBatches into memory at the
            same time. And data chunks from the same grouping key should appear sequentially.
            """

            def row_stream():
                # Yield (grouping_key, row) pairs one row at a time so only a single
                # Arrow RecordBatch is converted to pandas at any moment.
                for batch in batches:
                    self._update_batch_size_stats(batch)
                    data_pandas = ArrowBatchTransformer.to_pandas(
                        batch,
                        timezone=self._timezone,
                        schema=self._input_type,
                        struct_in_pandas=self._struct_in_pandas,
                        ndarray_as_list=self._ndarray_as_list,
                        prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                        df_for_struct=self._df_for_struct,
                    )
                    for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
                        batch_key = tuple(row[s] for s in self.key_offsets)
                        yield (batch_key, row)

            for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
                rows = []
                for _, row in group_rows:
                    rows.append(row)
                    # Flush the accumulated rows as one chunk once either the record
                    # count limit or the estimated byte-size limit is reached.
                    if (
                        len(rows) >= self.arrow_max_records_per_batch
                        or len(rows) * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
                    ):
                        yield (batch_key, pd.DataFrame(rows))
                        rows = []
                if rows:
                    # Flush the remaining rows of this grouping key.
                    yield (batch_key, pd.DataFrame(rows))

        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        data_batches = generate_data_batches(_batches)

        # Chunks for the same grouping key arrive sequentially, so one groupby pass
        # regroups them per key without materializing the whole stream.
        for k, g in groupby(data_batches, key=lambda x: x[0]):
            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)

    def dump_stream(self, iterator, stream):
        """
        Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
        RecordBatches, and write batches to stream.
        """

        def flatten_iterator():
            # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]]
            for packed in iterator:
                iter_pdf_with_type = packed[0]
                iter_pdf = iter_pdf_with_type[0]
                pdf_type = iter_pdf_with_type[1]
                for pdf in iter_pdf:
                    # Re-wrap each DataFrame in the shape the parent serializer expects.
                    yield [(pdf, pdf_type)]

        super().dump_stream(flatten_iterator(), stream)
| |
| |
class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
    """
    Serializer used by Python worker to evaluate UDF for
    :meth:`pyspark.sql.GroupedData.transformWithStateInPandasInitStateSerializer`.
    Parameters
    ----------
    Same as input parameters in TransformWithStateInPandasSerializer.
    """

    def __init__(
        self,
        *,
        timezone,
        safecheck,
        assign_cols_by_name,
        prefer_int_ext_dtype,
        arrow_max_records_per_batch,
        arrow_max_bytes_per_batch,
        int_to_decimal_coercion_enabled,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=assign_cols_by_name,
            prefer_int_ext_dtype=prefer_int_ext_dtype,
            arrow_max_records_per_batch=arrow_max_records_per_batch,
            arrow_max_bytes_per_batch=arrow_max_bytes_per_batch,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
        )
        # Positions of the grouping key columns within an initial-state row; assigned
        # by the worker before load_stream is consumed.
        self.init_key_offsets = None

    def load_stream(self, stream):
        """
        Read ArrowRecordBatches carrying either input data or initial state, and yield
        (mode, key, chunks) tuples for the worker, where each chunk pairs an input
        DataFrame with an initial-state DataFrame for the same grouping key.
        """
        import pyarrow as pa
        import pandas as pd
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        def generate_data_batches(batches):
            """
            Deserialize ArrowRecordBatches and return a generator of pandas.Series list.

            The deserialization logic assumes that Arrow RecordBatches contain the data with the
            ordering that data chunks for same grouping key will appear sequentially.
            See `TransformWithStateInPandasPythonInitialStateRunner` for arrow batch schema sent
            from JVM.
            This function flatten the columns of input rows and initial state rows and feed them
            into the data generator.
            """

            def flatten_columns(cur_batch, col_name):
                # Flatten the struct column `col_name` into a Table whose columns are
                # the struct's fields.
                state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

                # Check if the entire column is null
                if state_column.null_count == len(state_column):
                    # Return empty table with no columns
                    return pa.Table.from_arrays([], names=[])

                state_field_names = [
                    state_column.type[i].name for i in range(state_column.type.num_fields)
                ]
                state_field_arrays = [
                    state_column.field(i) for i in range(state_column.type.num_fields)
                ]
                table_from_fields = pa.Table.from_arrays(
                    state_field_arrays, names=state_field_names
                )
                return table_from_fields

            """
            The arrow batch is written in the schema:
            schema: StructType = new StructType()
                .add("inputData", dataSchema)
                .add("initState", initStateSchema)
            We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
            data generator. Rows in the same batch may have different grouping keys,
            but each batch will have either init_data or input_data, not mix.
            """

            def to_pandas(table):
                # Convert an Arrow table into a list of pandas.Series, one per column.
                return ArrowBatchTransformer.to_pandas(
                    table,
                    timezone=self._timezone,
                    schema=self._input_type,
                    struct_in_pandas=self._struct_in_pandas,
                    ndarray_as_list=self._ndarray_as_list,
                    prefer_int_ext_dtype=self._prefer_int_ext_dtype,
                    df_for_struct=self._df_for_struct,
                )

            def row_stream():
                # Yield (key, input_row_or_None, init_state_row_or_None) triples one
                # row at a time.
                for batch in batches:
                    self._update_batch_size_stats(batch)

                    data_table = flatten_columns(batch, "inputData")
                    init_table = flatten_columns(batch, "initState")

                    # Check column count - empty table has no columns
                    has_data = data_table.num_columns > 0
                    has_init = init_table.num_columns > 0

                    # A single batch carries either input data or initial state, never both.
                    assert not (has_data and has_init)

                    if has_data:
                        for row in pd.concat(to_pandas(data_table), axis=1).itertuples(index=False):
                            batch_key = tuple(row[s] for s in self.key_offsets)
                            yield (batch_key, row, None)
                    elif has_init:
                        for row in pd.concat(to_pandas(init_table), axis=1).itertuples(index=False):
                            batch_key = tuple(row[s] for s in self.init_key_offsets)
                            yield (batch_key, None, row)

            EMPTY_DATAFRAME = pd.DataFrame()
            for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
                rows = []
                init_state_rows = []
                for _, row, init_state_row in group_rows:
                    if row is not None:
                        rows.append(row)
                    if init_state_row is not None:
                        init_state_rows.append(init_state_row)

                    # Flush a chunk once the combined record count hits the record limit
                    # or the estimated byte size hits the byte limit.
                    total_len = len(rows) + len(init_state_rows)
                    if (
                        total_len >= self.arrow_max_records_per_batch
                        or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
                    ):
                        yield (
                            batch_key,
                            pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
                            (
                                pd.DataFrame(init_state_rows)
                                if len(init_state_rows) > 0
                                else EMPTY_DATAFRAME.copy()
                            ),
                        )
                        rows = []
                        init_state_rows = []
                if rows or init_state_rows:
                    # Flush the remaining rows of this grouping key.
                    yield (
                        batch_key,
                        pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
                        (
                            pd.DataFrame(init_state_rows)
                            if len(init_state_rows) > 0
                            else EMPTY_DATAFRAME.copy()
                        ),
                    )

        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        data_batches = generate_data_batches(_batches)

        for k, g in groupby(data_batches, key=lambda x: x[0]):
            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
| |
| |
class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
    """
    Serializer used by Python worker to evaluate UDF for
    :meth:`pyspark.sql.GroupedData.transformWithState`.

    Parameters
    ----------
    arrow_max_records_per_batch : int
        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
    """

    def __init__(self, *, arrow_max_records_per_batch):
        super().__init__()
        # A non-positive limit means "unlimited"; cap at max 32-bit signed int.
        self.arrow_max_records_per_batch = (
            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
        )
        # Positions of the grouping key columns within a row; assigned by the worker
        # before load_stream is consumed.
        self.key_offsets = None

    def load_stream(self, stream):
        """
        Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks,
        and convert the data into a list of pandas.Series.

        Please refer the doc of inner function `generate_data_batches` for more details how
        this function works in overall.
        """
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        def generate_data_batches(batches):
            """
            Deserialize ArrowRecordBatches and return a generator of Row.

            The deserialization logic assumes that Arrow RecordBatches contain the data with the
            ordering that data chunks for same grouping key will appear sequentially.

            This function must avoid materializing multiple Arrow RecordBatches into memory at the
            same time. And data chunks from the same grouping key should appear sequentially.
            """
            for batch in batches:
                DataRow = Row(*batch.schema.names)

                # Iterate row by row without converting the whole batch
                num_cols = batch.num_columns
                for row_idx in range(batch.num_rows):
                    # build the key for this row
                    row_key = tuple(batch.column(o)[row_idx].as_py() for o in self.key_offsets)
                    row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols)))
                    yield row_key, row

        _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
        data_batches = generate_data_batches(_batches)

        for k, g in groupby(data_batches, key=lambda x: x[0]):
            # Strip the key and hand the caller a lazy iterator over this group's rows.
            # (A generator expression replaces a redundant itertools.chain wrapper plus
            # a lambda-based map; behavior is identical.)
            rows = (row for _, row in g)
            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, rows)

        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)

    def dump_stream(self, iterator, stream):
        """
        Read through an iterator of (iterator of Row), serialize them to Arrow
        RecordBatches, and write batches to stream.
        """
        import pyarrow as pa

        from pyspark.sql.pandas.types import to_arrow_type

        def flatten_iterator():
            # iterator: iter[list[(iter[Row], spark_type)]]
            for packed in iterator:
                iter_row_with_type = packed[0]
                iter_row = iter_row_with_type[0]
                spark_type = iter_row_with_type[1]

                # Convert spark type to arrow type
                # TODO: We need to make this configurable, currently using default values.
                arrow_type = to_arrow_type(
                    spark_type,
                    timezone="UTC",
                    prefers_large_types=False,
                )

                # Materialize rows as plain (recursively converted) dicts so that
                # pyarrow can build the RecordBatch directly from them.
                rows_as_dict = [row.asDict(True) for row in iter_row]

                pdf_schema = pa.schema(list(arrow_type))
                record_batch = pa.RecordBatch.from_pylist(rows_as_dict, schema=pdf_schema)

                yield (record_batch, arrow_type)

        return ArrowStreamUDFSerializer.dump_stream(self, flatten_iterator(), stream)
| |
| |
class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer):
    """
    Serializer used by Python worker to evaluate UDF for
    :meth:`pyspark.sql.GroupedData.transformWithStateInPySparkRowInitStateSerializer`.
    Parameters
    ----------
    Same as input parameters in TransformWithStateInPySparkRowSerializer.
    """

    def __init__(self, *, arrow_max_records_per_batch):
        super().__init__(arrow_max_records_per_batch=arrow_max_records_per_batch)
        # Positions of the grouping key columns within an initial-state row; assigned
        # by the worker before load_stream is consumed.
        self.init_key_offsets = None

    def load_stream(self, stream):
        """
        Read ArrowRecordBatches carrying either input data or initial state, and yield
        (mode, key, (input_rows_iter, init_state_rows_iter)) tuples for the worker.
        """
        import pyarrow as pa
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]:
            """
            Deserialize ArrowRecordBatches and return a generator of Row.
            The deserialization logic assumes that Arrow RecordBatches contain the data with the
            ordering that data chunks for same grouping key will appear sequentially.
            See `TransformWithStateInPySparkPythonInitialStateRunner` for arrow batch schema sent
            from JVM.
            This function flattens the columns of input rows and initial state rows and feed them
            into the data generator.
            """

            def extract_rows(
                cur_batch, col_name, key_offsets
            ) -> Optional[Iterator[Tuple[Any, Any]]]:
                # Flatten the struct column `col_name` and return an iterator of
                # (grouping_key, Row) pairs, or None when the column carries no rows.
                data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

                # Check if the entire column is null
                if data_column.null_count == len(data_column):
                    return None

                data_field_names = [
                    data_column.type[i].name for i in range(data_column.type.num_fields)
                ]
                data_field_arrays = [
                    data_column.field(i) for i in range(data_column.type.num_fields)
                ]

                DataRow = Row(*data_field_names)

                table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)

                if table.num_rows == 0:
                    return None

                def row_iterator():
                    for row_idx in range(table.num_rows):
                        key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
                        row = DataRow(
                            *(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
                        )
                        yield (key, row)

                return row_iterator()

            """
            The arrow batch is written in the schema:
            schema: StructType = new StructType()
                .add("inputData", dataSchema)
                .add("initState", initStateSchema)
            We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
            data generator. Each batch will have either init_data or input_data, not mix.
            """
            for batch in batches:
                # Detect which column has data - each batch contains only one type
                input_result = extract_rows(batch, "inputData", self.key_offsets)
                init_result = extract_rows(batch, "initState", self.init_key_offsets)

                assert not (input_result is not None and init_result is not None)

                if input_result is not None:
                    for key, input_data_row in input_result:
                        yield (key, input_data_row, None)
                elif init_result is not None:
                    for key, init_state_row in init_result:
                        yield (key, None, init_state_row)

        _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
        data_batches = generate_data_batches(_batches)

        for k, g in groupby(data_batches, key=lambda x: x[0]):
            input_rows = []
            init_rows = []

            for batch_key, input_row, init_row in g:
                if input_row is not None:
                    input_rows.append(input_row)
                if init_row is not None:
                    init_rows.append(init_row)

                # Flush a chunk once the combined record count reaches the limit.
                total_len = len(input_rows) + len(init_rows)
                if total_len >= self.arrow_max_records_per_batch:
                    ret_tuple = (iter(input_rows), iter(init_rows))
                    yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
                    input_rows = []
                    init_rows = []

            if input_rows or init_rows:
                # Flush the remaining rows of this grouping key.
                ret_tuple = (iter(input_rows), iter(init_rows))
                yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)

        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)