#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union, Tuple, Optional

from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.types import StructType
from pyspark.errors import PySparkRuntimeError

__all__ = ["ValueStateClient"]


class ValueStateClient:
    """
    Client for a single value state variable. Translates value state operations
    (exists/get/update/clear) into ``ValueStateCall`` protobuf requests and sends them
    to the state server through the ``StatefulProcessorApiClient``.
    """

    def __init__(
        self,
        stateful_processor_api_client: StatefulProcessorApiClient,
        schema: Union[StructType, str],
    ) -> None:
        self._stateful_processor_api_client = stateful_processor_api_client
        if isinstance(schema, str):
            # Parse a schema provided as a string into a StructType.
            self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
        else:
            self.schema = schema

    def exists(self, state_name: str) -> bool:
        """Return True if the value state variable ``state_name`` currently holds a value."""
        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

        exists_call = stateMessage.Exists()
        value_state_call = stateMessage.ValueStateCall(stateName=state_name, exists=exists_call)
        state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)

        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
        response_message = self._stateful_processor_api_client._receive_proto_message()
        status = response_message[0]
        if status == 0:
            return True
        elif status == 2:
            # Status code 2 indicates that the state variable does not have a value.
            return False
        else:
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(
                f"Error checking value state exists: {response_message[1]}"
            )

    def get(self, state_name: str) -> Optional[Tuple]:
        """Return the current value of ``state_name`` as a tuple, or None if it has no value."""
        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

        get_call = stateMessage.Get()
        value_state_call = stateMessage.ValueStateCall(stateName=state_name, get=get_call)
        state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)

        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
        response_message = self._stateful_processor_api_client._receive_proto_message()
        status = response_message[0]
        if status == 0:
            if len(response_message[2]) == 0:
                # An empty payload means the state variable has no value.
                return None
            data = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
            return tuple(data)
        else:
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error getting value state: {response_message[1]}")

    def update(self, state_name: str, value: Tuple) -> None:
        """Serialize ``value`` with the state schema and store it in ``state_name``."""
        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

        value_bytes = self._stateful_processor_api_client._serialize_to_bytes(self.schema, value)
        update_call = stateMessage.ValueStateUpdate(value=value_bytes)
        value_state_call = stateMessage.ValueStateCall(
            stateName=state_name, valueStateUpdate=update_call
        )
        state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)

        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
        response_message = self._stateful_processor_api_client._receive_proto_message()
        status = response_message[0]
        if status != 0:
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error updating value state: {response_message[1]}")

    def clear(self, state_name: str) -> None:
        """Remove the value stored in ``state_name``."""
        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

        clear_call = stateMessage.Clear()
        value_state_call = stateMessage.ValueStateCall(stateName=state_name, clear=clear_call)
        state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call)
        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)

        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
        response_message = self._stateful_processor_api_client._receive_proto_message()
        status = response_message[0]
        if status != 0:
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error clearing value state: {response_message[1]}")