| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import asyncio |
| import typing |
| from datetime import timedelta |
| |
| import statefun.context |
| from statefun.core import parse_typename, ValueSpec, SdkAddress |
| from statefun.messages import Message, EgressMessage |
| from statefun.statefun_builder import StatefulFunctions, StatefulFunction |
| |
| # generated function protocol |
| from statefun.request_reply_pb2 import ToFunction, FromFunction, Address, TypedValue |
| from statefun.storage import resolve, Cell |
| |
| from dataclasses import dataclass |
| |
| |
@dataclass
class DelayedMessage:
    """
    A single entry of the delayed-message buffer kept by a context.

    A record is either a request to deliver ``message`` after ``duration`` milliseconds (optionally tagged
    with ``cancellation_token``), or, when ``is_cancellation`` is set, a request to cancel the delayed
    message that was registered under ``cancellation_token``.
    """
    # True for a cancellation request, False for a delayed send.
    is_cancellation: typing.Optional[bool] = None
    # the delay in milliseconds, left unset for cancellation requests.
    duration: typing.Optional[int] = None
    # the message to deliver, left unset for cancellation requests.
    # there must be no trailing comma after the default, otherwise the default becomes the tuple (None,).
    # the type is quoted because the annotation is documentation only and does not have to be evaluated
    # when the class is created.
    message: typing.Optional["Message"] = None
    # the token that identifies the delayed message, either to tag it or to cancel it.
    cancellation_token: typing.Optional[str] = None
| |
| |
class UserFacingContext(statefun.context.Context):
    """
    The context that is handed to a user function while it is being invoked.

    It exposes the address, the storage and the caller of the function, and it buffers every outgoing
    message, delayed message, cancellation request and egress message that the function produces.
    The buffers are only appended to by this class.
    """
    __slots__ = (
        "_self_address", "_outgoing_messages", "_outgoing_delayed_messages", "_outgoing_egress_messages", "_storage",
        "_caller")

    def __init__(self, address, storage):
        """
        :param address: the address of the function that is being invoked.
        :param storage: the storage object that is made available to the function.
        """
        self._self_address = address
        self._outgoing_messages = []
        self._outgoing_delayed_messages: typing.List[DelayedMessage] = []
        self._outgoing_egress_messages = []
        self._storage = storage
        # no caller is known until one is assigned to this context.
        self._caller = None

    @property
    def address(self) -> SdkAddress:
        """

        :return: the address of the currently executing function. the address is of the form (typename, id)
        """
        return self._self_address

    @property
    def storage(self):
        """

        :return: the storage object that this context was created with.
        """
        return self._storage

    @property
    def caller(self):
        """

        :return: the caller that was assigned to this context, or None if none was assigned.
        """
        return self._caller

    def send(self, message: Message):
        """
        Send a message to a function.

        :param message: a message to send.
        """
        self._outgoing_messages.append(message)

    def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
        """
        Send a message to a target function after a specified delay.

        :param duration: the amount of time to wait before sending this message out.
        :param message: the message to send.
        :param cancellation_token: an optional cancellation token to associate with this message.
        """
        # divide by a one millisecond timedelta rather than scaling total_seconds() by 1000.0: the float
        # product can land just below a whole number (1.001 * 1000.0 == 1000.9999999999999) and would then
        # be truncated to one millisecond less than requested. int() still drops any sub-millisecond part.
        ms = int(duration / timedelta(milliseconds=1))
        record = DelayedMessage(is_cancellation=False, duration=ms, message=message,
                                cancellation_token=cancellation_token)
        self._outgoing_delayed_messages.append(record)

    def cancel_delayed_message(self, cancellation_token: str):
        """
        Cancel a delayed message (message that was sent using send_after) with a given token.

        Please note that this is a best-effort operation, since the message might have been already delivered.
        If the message was delivered, this is a no-op operation.

        :param cancellation_token: the token that was passed to send_after for the message to cancel.
        """
        # cancellations share the delayed messages buffer, so they keep their order relative to the sends.
        record = DelayedMessage(is_cancellation=True, cancellation_token=cancellation_token)
        self._outgoing_delayed_messages.append(record)

    def send_egress(self, message: EgressMessage):
        """
        Send a message to an egress.

        :param message: the EgressMessage to send.
        """
        self._outgoing_egress_messages.append(message)
| |
| |
| # ------------------------------------------------------------------------------------------------------------------- |
| # Protobuf Helpers |
| # ------------------------------------------------------------------------------------------------------------------- |
| |
def sdk_address_from_pb(addr: Address) -> typing.Optional[SdkAddress]:
    """
    Convert a protobuf Address into an SdkAddress.

    :param addr: the protobuf address to convert.
    :return: the equivalent SdkAddress, or None if the address is missing or empty.
    """
    if not addr:
        return None
    if not (addr.namespace or addr.type or addr.id):
        # an unset protobuf message field reads back as a default instance, which is truthy and has only
        # empty strings, so an absent address can only be recognized by its content.
        return None
    return SdkAddress(namespace=addr.namespace,
                      name=addr.type,
                      id=addr.id,
                      typename=f"{addr.namespace}/{addr.type}")
| |
| |
# noinspection PyProtectedMember
def collect_success(ctx: UserFacingContext) -> FromFunction:
    """
    Build the response for a batch that was invoked successfully.

    :param ctx: the context that was passed to the function during the batch.
    :return: a FromFunction whose invocation result carries the outgoing messages, the delayed messages,
             the egress messages and the state mutations that were recorded by the context.
    """
    response = FromFunction()
    result = response.invocation_result
    collect_messages(ctx._outgoing_messages, result)
    collect_delayed(ctx._outgoing_delayed_messages, result)
    collect_egress(ctx._outgoing_egress_messages, result)
    # the storage cells are read directly, only the dirty ones end up in the response.
    collect_mutations(ctx._storage._cells, result)
    return response
| |
| |
def collect_failure(missing_state_specs: typing.List[ValueSpec]) -> FromFunction:
    """
    Build the response that reports state values which were missing from the request.

    :param missing_state_specs: the specs of the state values that could not be resolved.
    :return: a FromFunction whose incomplete invocation context lists one missing value, including its
             expiration spec, per given state spec.
    """
    expire_modes = FromFunction.ExpirationSpec.ExpireMode
    pb_from_function = FromFunction()
    missing_values = pb_from_function.incomplete_invocation_context.missing_values
    for state_spec in missing_state_specs:
        missing_value = missing_values.add()
        missing_value.state_name = state_spec.name
        missing_value.type_typename = state_spec.type.typename

        # expire after call takes precedence over expire after write, no expiration is the fallback.
        protocol_expiration_spec = FromFunction.ExpirationSpec()
        if state_spec.after_call:
            protocol_expiration_spec.mode = expire_modes.AFTER_INVOKE
            # presumably the duration is already expressed in milliseconds - confirm against ValueSpec.
            protocol_expiration_spec.expire_after_millis = state_spec.duration
        elif state_spec.after_write:
            protocol_expiration_spec.mode = expire_modes.AFTER_WRITE
            protocol_expiration_spec.expire_after_millis = state_spec.duration
        else:
            protocol_expiration_spec.mode = expire_modes.NONE
        missing_value.expiration_spec.CopyFrom(protocol_expiration_spec)
    return pb_from_function
| |
| |
def collect_messages(messages: typing.List[Message], pb_invocation_result):
    """
    Append the given messages to the outgoing messages of an invocation result.

    :param messages: the messages to send out.
    :param pb_invocation_result: the protobuf invocation result to add the messages to.
    """
    pb_outgoing_messages = pb_invocation_result.outgoing_messages
    for message in messages:
        pb_message = pb_outgoing_messages.add()

        # the target typename is split into the namespace and the type that the protocol expects.
        target_namespace, target_type = parse_typename(message.target_typename)
        pb_target = pb_message.target
        pb_target.namespace = target_namespace
        pb_target.type = target_type
        pb_target.id = message.target_id
        pb_message.argument.CopyFrom(message.typed_value)
| |
| |
def collect_delayed(delayed_messages: typing.List[DelayedMessage], invocation_result):
    """
    Append the given delayed message records to the delayed invocations of an invocation result.

    A cancellation record only carries its token and the cancellation flag, any other record carries
    the target address, the delay, the argument and the token, if one is present.

    :param delayed_messages: the delayed sends and cancellation requests, in the order they were issued.
    :param invocation_result: the protobuf invocation result to add the delayed invocations to.
    """
    delayed_invocations = invocation_result.delayed_invocations
    for record in delayed_messages:
        pb_delayed = delayed_invocations.add()

        if record.is_cancellation:
            # a cancellation request has neither a target nor an argument.
            pb_delayed.cancellation_token = record.cancellation_token
            pb_delayed.is_cancellation_request = True
            continue

        message = record.message
        target_namespace, target_type = parse_typename(message.target_typename)

        pb_target = pb_delayed.target
        pb_target.namespace = target_namespace
        pb_target.type = target_type
        pb_target.id = message.target_id
        pb_delayed.delay_in_ms = record.duration
        pb_delayed.argument.CopyFrom(message.typed_value)
        token = record.cancellation_token
        if token is not None:
            pb_delayed.cancellation_token = token
| |
| |
def collect_cancellations(tokens: typing.List[str], invocation_result):
    """
    Append one delay cancellation per non-empty token to an invocation result.

    :param tokens: the cancellation tokens, falsy entries (empty strings, None) are skipped.
    :param invocation_result: the protobuf invocation result to add the cancellations to.
    """
    outgoing_cancellations = invocation_result.outgoing_delay_cancellations
    # filter(None, ...) keeps only the truthy tokens.
    for token in filter(None, tokens):
        outgoing_cancellations.add().cancellation_token = token
| |
| |
def collect_egress(egresses: typing.List[EgressMessage], invocation_result):
    """
    Append the given egress messages to the outgoing egresses of an invocation result.

    :param egresses: the egress messages to send out.
    :param invocation_result: the protobuf invocation result to add the egress messages to.
    """
    outgoing_egresses = invocation_result.outgoing_egresses
    for egress_message in egresses:
        pb_egress = outgoing_egresses.add()

        # the egress typename is split into the namespace and the type that the protocol expects.
        egress_namespace, egress_type = parse_typename(egress_message.typename)
        pb_egress.egress_namespace = egress_namespace
        pb_egress.egress_type = egress_type
        pb_egress.argument.CopyFrom(egress_message.typed_value)
| |
| |
def collect_mutations(cells: typing.Dict[str, Cell], invocation_result):
    """
    Append a state mutation for every dirty cell to an invocation result.

    :param cells: the storage cells, keyed by their state name.
    :param invocation_result: the protobuf invocation result to add the state mutations to.
    """
    for state_name, cell in cells.items():
        if cell.dirty:
            pb_mutation = invocation_result.state_mutations.add()
            pb_mutation.state_name = state_name
            typed_value: typing.Optional[TypedValue] = cell.typed_value
            if typed_value is not None:
                pb_mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('MODIFY')
                pb_mutation.state_value.CopyFrom(typed_value)
            else:
                # a dirty cell without a value means that the state was deleted.
                pb_mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('DELETE')
| |
| |
| # -------------------------------------------------------------------------------------------------------------------- |
| # The main Request Reply Handler. |
| # -------------------------------------------------------------------------------------------------------------------- |
| |
class RequestReplyHandler(object):
    """
    Handles a serialized ToFunction request by invoking the addressed function with every invocation of
    the batch, and produces the serialized FromFunction response.
    """

    def __init__(self, functions: StatefulFunctions):
        """
        :param functions: the registry that is used to look up the target function by its typename.
        :raises ValueError: if functions is missing or falsy.
        """
        if not functions:
            raise ValueError("functions must be provided.")
        self.functions = functions

    def handle_sync(self, request_bytes: typing.Union[str, bytes, bytearray]) -> bytes:
        """
        Handle a request synchronously, by running handle_async to completion with asyncio.run.

        Since asyncio.run is used, this must not be called while another event loop is running in the
        same thread, use handle_async in that case.

        :param request_bytes: the serialized ToFunction request.
        :return: the serialized FromFunction response.
        """
        return asyncio.run(self.handle_async(request_bytes))

    async def handle_async(self, request_bytes: typing.Union[str, bytes, bytearray]) -> bytes:
        """
        Handle a request and await the target function, if it is an async function.

        :param request_bytes: the serialized ToFunction request.
        :return: the serialized FromFunction response. if some of the state values of the function are
                 missing from the request, the response lists them and the function is not invoked.
        :raises ValueError: if the request has no target address, or if no function is known for it.
        """
        # parse
        pb_to_function = ToFunction()
        pb_to_function.ParseFromString(request_bytes)
        # target address
        pb_target_address = pb_to_function.invocation.target
        sdk_address = sdk_address_from_pb(pb_target_address)
        if sdk_address is None:
            # the conversion is declared to return an optional, fail explicitly rather than on an attribute access.
            raise ValueError("The request does not contain a target address.")
        # target stateful function
        target_fn: StatefulFunction = self.functions.for_typename(sdk_address.typename)
        if not target_fn:
            raise ValueError(f"Unable to find a function of type {sdk_address.typename}")
        # resolve state
        res = resolve(target_fn.storage_spec, sdk_address.typename, pb_to_function.invocation.state)
        if res.missing_specs:
            pb_from_function = collect_failure(res.missing_specs)
            return pb_from_function.SerializeToString()
        # invoke the batch, all the invocations share the same context.
        ctx = UserFacingContext(sdk_address, res.storage)
        fun = target_fn.fun
        is_async = target_fn.is_async
        for pb_invocation in pb_to_function.invocation.invocations:
            msg = Message(target_typename=sdk_address.typename, target_id=sdk_address.id,
                          typed_value=pb_invocation.argument)
            # the caller differs per invocation, so it is replaced before each call.
            ctx._caller = sdk_address_from_pb(pb_invocation.caller)
            if is_async:
                # await for an async function to complete.
                # noinspection PyUnresolvedReferences
                await fun(ctx, msg)
            else:
                # a regular function is called directly.
                fun(ctx, msg)
        # collect the results
        pb_from_function = collect_success(ctx)
        return pb_from_function.SerializeToString()