blob: b86a609183187ec490c729a58d1b09b3445d0790 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package statefun
import (
// StatefulFunctions is a registry for multiple StatefulFunction's. A RequestReplyHandler
// can be created from the registry that understands how to dispatch
// invocation requests to the registered functions as well as encode
// side-effects (e.g., sending messages to other functions or updating
// values in storage) as the response.
type StatefulFunctions interface {
// WithSpec registers a StatefulFunctionSpec, which will be
// used to build the runtime function. It returns an error
// if the specification is invalid and the handler
// fails to register the function.
WithSpec(spec StatefulFunctionSpec) error
// AsHandler creates a RequestReplyHandler from the registered
// function specs.
AsHandler() RequestReplyHandler
// RequestReplyHandler processes messages from the runtime, invokes functions,
// and encodes their side effects into the reply.
//
// It embeds http.Handler, so it can be mounted directly in standard Go HTTP
// servers and frameworks. It also exposes Invoke for environments that hand the
// raw request payload to a function, such as AWS Lambda.
type RequestReplyHandler interface {
	http.Handler

	// Invoke processes one serialized request payload and returns the
	// serialized response. Its signature provides compliance with the AWS
	// Lambda handler.
	Invoke(ctx context.Context, payload []byte) ([]byte, error)
}
// StatefulFunctionsBuilder creates a new StatefulFunctions registry.
func StatefulFunctionsBuilder() StatefulFunctions {
return &handler{
module: map[TypeName]StatefulFunction{},
stateSpecs: map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec{},
type handler struct {
module map[TypeName]StatefulFunction
stateSpecs map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec
func (h *handler) WithSpec(spec StatefulFunctionSpec) error {
log.Printf("registering Stateful Function %v\n", spec.FunctionType)
if _, exists := h.module[spec.FunctionType]; exists {
err := fmt.Errorf("failed to register Stateful Function %s, there is already a spec registered under that type", spec.FunctionType)
return err
if spec.Function == nil {
err := fmt.Errorf("failed to register Stateful Function %s, the Function instance cannot be nil", spec.FunctionType)
return err
valueSpecs := make(map[string]*protocol.FromFunction_PersistedValueSpec, len(spec.States))
for _, state := range spec.States {
log.Printf("registering state specification %v\n", state)
if err := validateValueSpec(state); err != nil {
err := fmt.Errorf("failed to register Stateful Function %s: %w", spec.FunctionType, err)
return err
expiration := &protocol.FromFunction_ExpirationSpec{}
switch state.Expiration.expirationType {
case none:
expiration.Mode = protocol.FromFunction_ExpirationSpec_NONE
case expireAfterWrite:
expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_WRITE
expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds()
case expireAfterCall:
expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_INVOKE
expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds()
valueSpecs[state.Name] = &protocol.FromFunction_PersistedValueSpec{
StateName: state.Name,
ExpirationSpec: expiration,
TypeTypename: state.ValueType.GetTypeName().String(),
h.module[spec.FunctionType] = spec.Function
h.stateSpecs[spec.FunctionType] = valueSpecs
return nil
func (h *handler) AsHandler() RequestReplyHandler {
return h
func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method != "POST" {
http.Error(writer, "invalid request method", http.StatusMethodNotAllowed)
contentType := request.Header.Get("Content-type")
if contentType != "" && contentType != "application/octet-stream" {
http.Error(writer, "invalid content type", http.StatusUnsupportedMediaType)
if request.Body == nil || request.ContentLength == 0 {
http.Error(writer, "empty request body", http.StatusBadRequest)
buffer := bytes.Buffer{}
if _, err := buffer.ReadFrom(request.Body); err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
response, err := h.Invoke(request.Context(), buffer.Bytes())
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
_, _ = writer.Write(response)
func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
toFunction := protocol.ToFunction{}
if err := proto.Unmarshal(payload, &toFunction); err != nil {
return nil, fmt.Errorf("failed to unmarshal ToFunction: %w", err)
fromFunction, err := h.invoke(ctx, &toFunction)
if err != nil {
return nil, err
return proto.Marshal(fromFunction)
func (h *handler) invoke(ctx context.Context, toFunction *protocol.ToFunction) (from *protocol.FromFunction, err error) {
batch := toFunction.GetInvocation()
self := addressFromInternal(batch.Target)
function, exists := h.module[self.FunctionType]
defer func() {
if r := recover(); r != nil {
switch r := r.(type) {
case error:
err = fmt.Errorf("failed to execute invocation for %s: %w", batch.Target, r)
if !exists {
return nil, fmt.Errorf("unknown function type %s", self.FunctionType)
storageFactory := newStorageFactory(batch, h.stateSpecs[self.FunctionType])
if missing := storageFactory.getMissingSpecs(); missing != nil {
log.Printf("missing state specs for function type %v", self)
for _, spec := range missing {
log.Printf("registering missing specs %v", spec)
return &protocol.FromFunction{
Response: &protocol.FromFunction_IncompleteInvocationContext_{
IncompleteInvocationContext: &protocol.FromFunction_IncompleteInvocationContext{
MissingValues: missing,
}, nil
storage := storageFactory.getStorage()
response := &protocol.FromFunction_InvocationResponse{}
for _, invocation := range batch.Invocations {
select {
case <-ctx.Done():
return nil, ctx.Err()
sContext := statefunContext{
self: self,
storage: storage,
response: response,
var cancel context.CancelFunc
sContext.Context, cancel = context.WithCancel(ctx)
var caller Address
if invocation.Caller != nil {
caller = addressFromInternal(invocation.Caller)
sContext.caller = &caller
msg := Message{
target: batch.Target,
typedValue: invocation.Argument,
err = function.Invoke(&sContext, msg)
if err != nil {
response.StateMutations = storage.getStateMutations()
from = &protocol.FromFunction{
Response: &protocol.FromFunction_InvocationResult{
InvocationResult: response,