blob: c926f45c2c15d89d9278af966c42f2acd5174e64 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
import itertools
from typing import Any, Iterator, Optional, TYPE_CHECKING, Union
from pyspark.sql.streaming.stateful_processor_api_client import (
StatefulProcessorApiClient,
StatefulProcessorHandleState,
)
from pyspark.sql.streaming.stateful_processor import (
ExpiredTimerInfo,
StatefulProcessor,
StatefulProcessorHandle,
TimerValues,
)
from pyspark.sql.streaming.stateful_processor_api_client import ExpiredTimerIterator
from pyspark.sql.types import Row
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
# This file contains the utilities for transformWithState in PySpark (Row and Pandas); we keep
# them in a separate file to avoid putting internal classes into the stateful_processor.py file,
# which contains public APIs.
class TransformWithStateInPandasFuncMode(Enum):
    """
    Internal mode for the python worker UDF for transformWithState in PySpark; external modes
    are in `StatefulProcessorHandleState` for public use purposes.

    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
    """

    # Process input data rows for a grouping key.
    PROCESS_DATA = 1
    # Process expired timers after all data rows have been handled.
    PROCESS_TIMER = 2
    # Finish the batch: close the user's stateful processor and the handle.
    COMPLETE = 3
    # Driver-side init used only to collect state schemas before execution.
    PRE_INIT = 4
class TransformWithStateInPandasUdfUtils:
    """
    Internal utility class used for the python worker UDF for transformWithState in PySpark.
    This class is shared for both classic and spark connect mode.

    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
    """

    def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
        # User-defined processor whose lifecycle (init / handleInputRows /
        # handleExpiredTimer / close) is driven by the methods below.
        self._stateful_processor = stateful_processor
        # Time mode string; only compared case-insensitively against "processingtime" and
        # "eventtime" in _handle_expired_timers — any other value means no timer expiry.
        # NOTE(review): exact set of accepted values is defined by the caller — confirm there.
        self._time_mode = time_mode

    def transformWithStateUDF(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        mode: TransformWithStateInPandasFuncMode,
        key: Any,
        input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        UDF for the TWS operator without initial state. Dispatches on ``mode``:

        - ``PRE_INIT``: run the driver-side init path (see ``_handle_pre_init``).
        - ``PROCESS_TIMER``: transition the handle to DATA_PROCESSED and emit output rows
          produced by the user's expired-timer callback.
        - ``COMPLETE``: transition to TIMER_PROCESSED, remove the implicit grouping key,
          close the user's processor, mark the handle CLOSED, and return an empty iterator.
        - ``PROCESS_DATA`` (the fallthrough): feed the grouped input rows to the user's
          processor via ``_handle_data_rows``.
        """
        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
            return self._handle_pre_init(stateful_processor_api_client)

        handle = StatefulProcessorHandle(stateful_processor_api_client)

        # Run the user's init() exactly once per handle: only while the handle is still in
        # the freshly-created state.
        if stateful_processor_api_client.handle_state == StatefulProcessorHandleState.CREATED:
            self._stateful_processor.init(handle)
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.DATA_PROCESSED
            )
            result = self._handle_expired_timers(stateful_processor_api_client)
            return result
        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.TIMER_PROCESSED
            )
            stateful_processor_api_client.remove_implicit_key()
            self._stateful_processor.close()
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])
        else:
            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
            result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
            return result

    def transformWithStateWithInitStateUDF(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        mode: TransformWithStateInPandasFuncMode,
        key: Any,
        input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
        initial_states: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        UDF for TWS operator with non-empty initial states. Possible input combinations
        of inputRows and initialStates iterator:
        - Both `inputRows` and `initialStates` are non-empty. Both input rows and initial
          states contains the grouping key and data.
        - `InitialStates` is non-empty, while `inputRows` is empty. Only initial states
          contains the grouping key and data, and it is first batch.
        - `initialStates` is empty, while `inputRows` is non-empty. Only inputRows contains the
          grouping key and data, and it is first batch.
        - `initialStates` is None, while `inputRows` is not empty. This is not first batch.
          `initialStates` is initialized to the positional value as None.
        """
        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
            return self._handle_pre_init(stateful_processor_api_client)

        handle = StatefulProcessorHandle(stateful_processor_api_client)

        # Run the user's init() exactly once per handle (same as transformWithStateUDF).
        if stateful_processor_api_client.handle_state == StatefulProcessorHandleState.CREATED:
            self._stateful_processor.init(handle)
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)

        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
            stateful_processor_api_client.set_handle_state(
                StatefulProcessorHandleState.DATA_PROCESSED
            )
            result = self._handle_expired_timers(stateful_processor_api_client)
            return result
        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
            # NOTE(review): unlike transformWithStateUDF, this path does not transition the
            # handle to TIMER_PROCESSED before closing — confirm the asymmetry is intentional.
            stateful_processor_api_client.remove_implicit_key()
            self._stateful_processor.close()
            stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])
        else:
            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
            batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
                self._time_mode
            )

            # only process initial state if first batch and initial state is not None
            if initial_states is not None:
                for cur_initial_state in initial_states:
                    stateful_processor_api_client.set_implicit_key(key)
                    self._stateful_processor.handleInitialState(
                        key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
                    )

            # if we don't have input rows for the given key but only have initial state
            # for the grouping key, the inputRows iterator could be empty
            input_rows_empty = False
            try:
                # Peek one element to detect emptiness, then stitch it back onto the
                # iterator so the user's processor still sees the full stream.
                first = next(input_rows)
            except StopIteration:
                input_rows_empty = True
            else:
                input_rows = itertools.chain([first], input_rows)  # type: ignore

            if not input_rows_empty:
                result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
            else:
                result = iter([])
            return result

    def _handle_pre_init(
        self, stateful_processor_api_client: StatefulProcessorApiClient
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Driver-side init path: run the user's init() and close() once against a driver
        handle so state schemas can be collected, then return an empty iterator.
        """
        # Driver handle is different from the handle used on executors;
        # On JVM side, we will use `DriverStatefulProcessorHandleImpl` for driver handle which
        # will only be used for handling init() and get the state schema on the driver.
        driver_handle = StatefulProcessorHandle(stateful_processor_api_client)
        stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.PRE_INIT)
        self._stateful_processor.init(driver_handle)

        # This method is used for the driver-side stateful processor after we have collected
        # all the necessary schemas. This instance of the DriverStatefulProcessorHandleImpl
        # won't be used again on JVM.
        self._stateful_processor.close()

        # return a dummy result, no return value is needed for pre init
        return iter([])

    def _handle_data_rows(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        key: Any,
        input_rows: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Set ``key`` as the implicit grouping key and hand ``input_rows`` to the user's
        ``handleInputRows`` with the current batch/watermark timestamps. Returns the user's
        output iterator, or an empty iterator when ``input_rows`` is None.
        """
        stateful_processor_api_client.set_implicit_key(key)

        batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
            self._time_mode
        )

        # process with data rows
        if input_rows is not None:
            data_iter = self._stateful_processor.handleInputRows(
                key, input_rows, TimerValues(batch_timestamp, watermark_timestamp)
            )
            return data_iter
        else:
            return iter([])

    def _handle_expired_timers(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
        """
        Generator yielding the user's output rows for every expired timer.

        The expiry threshold depends on the time mode: the batch timestamp for processing
        time, the watermark timestamp for event time; any other mode yields nothing. Each
        timer is deleted only after the user's output for it has been fully consumed.
        """
        batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
            self._time_mode
        )

        if self._time_mode.lower() == "processingtime":
            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, batch_timestamp)
        elif self._time_mode.lower() == "eventtime":
            expiry_iter = ExpiredTimerIterator(stateful_processor_api_client, watermark_timestamp)
        else:
            expiry_iter = iter([])  # type: ignore[assignment]

        # process with expiry timers, only timer related rows will be emitted
        for key_obj, expiry_timestamp in expiry_iter:
            stateful_processor_api_client.set_implicit_key(key_obj)
            for pd in self._stateful_processor.handleExpiredTimer(
                key=key_obj,
                timerValues=TimerValues(batch_timestamp, watermark_timestamp),
                expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp),
            ):
                yield pd
            stateful_processor_api_client.delete_timer(expiry_timestamp)