| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| from enum import Enum |
| import itertools |
| from typing import Any, Iterator, Optional, TYPE_CHECKING, Union |
| from pyspark.sql.streaming.stateful_processor_api_client import ( |
| StatefulProcessorApiClient, |
| StatefulProcessorHandleState, |
| ) |
| from pyspark.sql.streaming.stateful_processor import ( |
| ExpiredTimerInfo, |
| StatefulProcessor, |
| StatefulProcessorHandle, |
| TimerValues, |
| ) |
| from pyspark.sql.streaming.stateful_processor_api_client import ExpiredTimerIterator |
| from pyspark.sql.types import Row |
| |
| if TYPE_CHECKING: |
| from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike |
| |
# This file contains the utilities for transformWithState in PySpark (Row and Pandas); we keep
# them in a separate file to avoid putting internal classes into the stateful_processor.py file,
# which contains public APIs.
| |
| |
class TransformWithStateInPandasFuncMode(Enum):
    """
    Internal mode for the Python worker UDF for transformWithState in PySpark; the externally
    visible states are in `StatefulProcessorHandleState` for public use purposes.

    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
    """

    # Process a batch of input rows for a grouping key.
    PROCESS_DATA = 1
    # Process expired timers, after all data rows for the batch have been handled.
    PROCESS_TIMER = 2
    # Final call for the batch: close the stateful processor.
    COMPLETE = 3
    # Driver-side phase used only to run init() and collect state schemas before execution.
    PRE_INIT = 4
| |
| |
class TransformWithStateInPandasUdfUtils:
    """
    Internal utility class used for the Python worker UDF for transformWithState in PySpark.
    This class is shared for both classic and spark connect mode.

    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
    """

    def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
        # User-defined processor whose init()/handleInputRows()/handleExpiredTimer()/close()
        # callbacks are driven by the UDF entry points below.
        self._stateful_processor = stateful_processor
        # Time mode string; compared case-insensitively against "processingtime" /
        # "eventtime" in _handle_expired_timers.
        self._time_mode = time_mode

    def transformWithStateUDF(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        mode: TransformWithStateInPandasFuncMode,
        key: Any,
        input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        UDF entry point for the TWS operator without initial state.

        Dispatches on `mode`:
        - PRE_INIT: driver-side init for schema collection (see _handle_pre_init);
        - PROCESS_TIMER: handle expired timers after all data rows are processed;
        - COMPLETE: close the processor and mark the handle CLOSED;
        - PROCESS_DATA: feed `input_rows` for `key` to the user's handleInputRows().
        Returns an iterator of output rows (empty for COMPLETE / PRE_INIT).
        """
        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
            return self._handle_pre_init(stateful_processor_api_client)

        handle = StatefulProcessorHandle(stateful_processor_api_client)

        # Run the user's init() exactly once per handle: only while the handle is still
        # in the freshly-created state.
        if stateful_processor_api_client.handle_state == StatefulProcessorHandleState.CREATED:
            self._stateful_processor.init(handle)
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
            # Timers are processed only after all data rows of the batch have been handled,
            # so advance the handle state before iterating expired timers.
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.DATA_PROCESSED
            )
            result = self._handle_expired_timers(stateful_processor_api_client)
            return result
        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.TIMER_PROCESSED
            )
            # Drop the implicit grouping key before closing so no state call can leak
            # a stale key, then let the user release resources in close().
            stateful_processor_api_client.remove_implicit_key()
            self._stateful_processor.close()
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])
        else:
            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
            result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
            return result

    def transformWithStateWithInitStateUDF(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        mode: TransformWithStateInPandasFuncMode,
        key: Any,
        input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
        initial_states: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        UDF for the TWS operator with non-empty initial states. Possible input combinations
        of the `input_rows` and `initial_states` iterators:
        - Both `input_rows` and `initial_states` are non-empty: both contain the grouping
          key and data.
        - `initial_states` is non-empty while `input_rows` is empty: only the initial states
          contain the grouping key and data, and it is the first batch.
        - `initial_states` is empty while `input_rows` is non-empty: only `input_rows`
          contains the grouping key and data, and it is the first batch.
        - `initial_states` is None while `input_rows` is non-empty: this is not the first
          batch; `initial_states` keeps its default positional value of None.
        """
        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
            return self._handle_pre_init(stateful_processor_api_client)

        handle = StatefulProcessorHandle(stateful_processor_api_client)

        # Run the user's init() exactly once per handle (same as transformWithStateUDF).
        if stateful_processor_api_client.handle_state == StatefulProcessorHandleState.CREATED:
            self._stateful_processor.init(handle)
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.DATA_PROCESSED
            )
            result = self._handle_expired_timers(stateful_processor_api_client)
            return result
        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
            # NOTE(review): unlike transformWithStateUDF, this branch does not pass through
            # the TIMER_PROCESSED state before closing — confirm whether this asymmetry in
            # the handle-state transition is intentional.
            stateful_processor_api_client.remove_implicit_key()
            self._stateful_processor.close()
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])
        else:
            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
            batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
                self._time_mode
            )

            # only process initial state if first batch and initial state is not None
            if initial_states is not None:
                for cur_initial_state in initial_states:
                    stateful_processor_api_client.set_implicit_key(key)
                    self._stateful_processor.handleInitialState(
                        key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
                    )

            # if we don't have input rows for the given key but only have initial state
            # for the grouping key, the inputRows iterator could be empty: peek one element
            # and, if present, stitch it back in front of the iterator.
            input_rows_empty = False
            try:
                first = next(input_rows)
            except StopIteration:
                input_rows_empty = True
            else:
                input_rows = itertools.chain([first], input_rows)  # type: ignore

            if not input_rows_empty:
                result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
            else:
                result = iter([])

            return result

    def _handle_pre_init(
        self, stateful_processor_api_client: StatefulProcessorApiClient
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Run the user's init()/close() once on the driver so the state schemas can be
        collected; returns an empty iterator (no output rows are produced in this phase).
        """
        # Driver handle is different from the handle used on executors;
        # On JVM side, we will use `DriverStatefulProcessorHandleImpl` for driver handle which
        # will only be used for handling init() and get the state schema on the driver.
        driver_handle = StatefulProcessorHandle(stateful_processor_api_client)
        stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.PRE_INIT)
        self._stateful_processor.init(driver_handle)

        # This method is used for the driver-side stateful processor after we have collected
        # all the necessary schemas. This instance of the DriverStatefulProcessorHandleImpl
        # won't be used again on JVM.
        self._stateful_processor.close()

        # return a dummy result, no return value is needed for pre init
        return iter([])

    def _handle_data_rows(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        key: Any,
        input_rows: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Set the implicit grouping key and feed `input_rows` to the user's handleInputRows();
        returns the user's output iterator, or an empty iterator when `input_rows` is None.
        """
        stateful_processor_api_client.set_implicit_key(key)

        batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
            self._time_mode
        )

        # process with data rows
        if input_rows is not None:
            data_iter = self._stateful_processor.handleInputRows(
                key, input_rows, TimerValues(batch_timestamp, watermark_timestamp)
            )
            return data_iter
        else:
            return iter([])

    def _handle_expired_timers(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Generator yielding the user's output for every expired timer.

        Timer expiry is judged against the batch timestamp for processing time and
        against the watermark timestamp for event time; any other time mode yields
        nothing. Because this is a generator, the iteration over expired timers (and
        the per-timer delete_timer call) happens lazily as the caller consumes it.
        """
        batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
            self._time_mode
        )

        if self._time_mode.lower() == "processingtime":
            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, batch_timestamp)
        elif self._time_mode.lower() == "eventtime":
            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, watermark_timestamp)
        else:
            expiry_iter = iter([])  # type: ignore[assignment]

        # process with expiry timers, only timer related rows will be emitted
        for key_obj, expiry_timestamp in expiry_iter:
            stateful_processor_api_client.set_implicit_key(key_obj)
            for pd in self._stateful_processor.handleExpiredTimer(
                key=key_obj,
                timerValues=TimerValues(batch_timestamp, watermark_timestamp),
                expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
            ):
                yield pd
            # Remove the timer only after the user's output for it has been emitted.
            stateful_processor_api_client.delete_timer(expiry_timestamp)