blob: 3f9f4f6eb54561747826ecfb768a9156a3f565e5 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package statefun
import (
"bytes"
"context"
"fmt"
"log"
"net/http"
"sync"
"github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol"
"google.golang.org/protobuf/proto"
)
// StatefulFunctions is a registry for multiple StatefulFunction's. A RequestReplyHandler
// can be created from the registry that understands how to dispatch
// invocation requests to the registered functions as well as encode
// side-effects (e.g., sending messages to other functions or updating
// values in storage) as the response.
type StatefulFunctions interface {
// WithSpec registers a StatefulFunctionSpec, which will be
// used to build the runtime function. It returns an error
// if the specification is invalid and the handler
// fails to register the function.
WithSpec(spec StatefulFunctionSpec) error
// AsHandler creates a RequestReplyHandler from the registered
// function specs.
AsHandler() RequestReplyHandler
}
// RequestReplyHandler processes messages coming from the runtime: it invokes
// the target functions and encodes their side effects as the reply. Because it
// implements http.Handler, it can be mounted directly in standard Go HTTP
// servers and frameworks.
type RequestReplyHandler interface {
	// Invoke handles a raw request payload and returns the raw reply, which
	// makes the handler usable as an AWS Lambda handler.
	Invoke(ctx context.Context, payload []byte) ([]byte, error)

	http.Handler
}
// StatefulFunctionsBuilder returns an empty StatefulFunctions registry.
// Functions are added to it with WithSpec and served through AsHandler.
func StatefulFunctionsBuilder() StatefulFunctions {
	registry := new(handler)
	registry.module = make(map[TypeName]StatefulFunction)
	registry.stateSpecs = make(map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec)
	return registry
}
// handler is the concrete registry returned by StatefulFunctionsBuilder; it
// also serves as the RequestReplyHandler returned by AsHandler.
type handler struct {
	// module holds the registered function implementation for each function type.
	module map[TypeName]StatefulFunction

	// stateSpecs holds, for each registered function type, its persisted value
	// specs keyed by state name; they are used to build the storage for each
	// invocation batch.
	stateSpecs map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec
}
// WithSpec registers spec.Function under spec.FunctionType, together with the
// persisted value specs derived from spec.States.
//
// Registration fails, and the registry is left unchanged, when a function is
// already registered under spec.FunctionType, when spec.Function is nil, when
// a state is rejected by validateValueSpec, or when two states share the same
// name. Every failure is logged before the error is returned.
func (h *handler) WithSpec(spec StatefulFunctionSpec) error {
	log.Printf("registering Stateful Function %v\n", spec.FunctionType)
	if _, exists := h.module[spec.FunctionType]; exists {
		err := fmt.Errorf("failed to register Stateful Function %s, there is already a spec registered under that type", spec.FunctionType)
		log.Println(err.Error())
		return err
	}
	if spec.Function == nil {
		err := fmt.Errorf("failed to register Stateful Function %s, the Function instance cannot be nil", spec.FunctionType)
		log.Println(err.Error())
		return err
	}

	valueSpecs := make(map[string]*protocol.FromFunction_PersistedValueSpec, len(spec.States))
	for _, state := range spec.States {
		log.Printf("registering state specification %v\n", state)
		if err := validateValueSpec(state); err != nil {
			err := fmt.Errorf("failed to register Stateful Function %s: %w", spec.FunctionType, err)
			log.Println(err.Error())
			return err
		}

		// Specs are keyed by state name, so a repeated name would silently
		// replace the earlier declaration; reject it instead.
		if _, duplicate := valueSpecs[state.Name]; duplicate {
			err := fmt.Errorf("failed to register Stateful Function %s, state %q is declared more than once", spec.FunctionType, state.Name)
			log.Println(err.Error())
			return err
		}

		// Translate the SDK expiration settings into the protocol form; any
		// other expiration type leaves the ExpirationSpec at its zero value.
		expiration := &protocol.FromFunction_ExpirationSpec{}
		switch state.Expiration.expirationType {
		case none:
			expiration.Mode = protocol.FromFunction_ExpirationSpec_NONE
		case expireAfterWrite:
			expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_WRITE
			expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds()
		case expireAfterCall:
			expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_INVOKE
			expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds()
		}

		valueSpecs[state.Name] = &protocol.FromFunction_PersistedValueSpec{
			StateName:      state.Name,
			ExpirationSpec: expiration,
			TypeTypename:   state.ValueType.GetTypeName().String(),
		}
	}

	// Only mutate the registry once the whole spec has been validated.
	h.module[spec.FunctionType] = spec.Function
	h.stateSpecs[spec.FunctionType] = valueSpecs
	return nil
}
// AsHandler exposes the registry as a RequestReplyHandler. The registry itself
// implements the handler, so the returned value shares the same registered
// functions. Registration is not guarded by a lock, so finish calling WithSpec
// before serving requests.
func (h *handler) AsHandler() RequestReplyHandler {
	var rrh RequestReplyHandler = h
	return rrh
}
// ServeHTTP implements http.Handler. It accepts only POST requests whose
// Content-Type is empty or "application/octet-stream" and whose body is
// non-empty. It passes the raw body to Invoke and writes Invoke's result back
// as an "application/octet-stream" response.
//
// Error statuses:
//   - 405 for non-POST methods (with an Allow header naming POST),
//   - 415 for any other Content-Type,
//   - 400 for an empty or unreadable body,
//   - 500 when Invoke fails; the error text is logged and sent to the client.
func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
	if request.Method != http.MethodPost {
		writer.Header().Set("Allow", http.MethodPost)
		http.Error(writer, "invalid request method", http.StatusMethodNotAllowed)
		return
	}

	contentType := request.Header.Get("Content-Type")
	if contentType != "" && contentType != "application/octet-stream" {
		http.Error(writer, "invalid content type", http.StatusUnsupportedMediaType)
		return
	}

	// Reject a declared-empty body before reading anything.
	if request.Body == nil || request.ContentLength == 0 {
		http.Error(writer, "empty request body", http.StatusBadRequest)
		return
	}

	buffer := bytes.Buffer{}
	if _, err := buffer.ReadFrom(request.Body); err != nil {
		http.Error(writer, err.Error(), http.StatusBadRequest)
		return
	}
	// ContentLength is -1 when the length is unknown (e.g. chunked encoding),
	// so the body may still turn out to be empty once it has been read.
	if buffer.Len() == 0 {
		http.Error(writer, "empty request body", http.StatusBadRequest)
		return
	}

	response, err := h.Invoke(request.Context(), buffer.Bytes())
	if err != nil {
		log.Println(err.Error())
		http.Error(writer, err.Error(), http.StatusInternalServerError)
		return
	}

	// Declare the binary payload explicitly; otherwise net/http sniffs the
	// bytes and may label the protobuf reply as text.
	writer.Header().Set("Content-Type", "application/octet-stream")
	// A failed write means the client connection is gone; there is no one
	// left to report the error to.
	_, _ = writer.Write(response)
}
// Invoke decodes a protobuf-encoded ToFunction payload, dispatches it to the
// registered function, and returns the protobuf-encoded FromFunction reply.
// It is the AWS Lambda entry point, and ServeHTTP delegates to it once the
// HTTP request has been validated.
func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
	var request protocol.ToFunction
	if err := proto.Unmarshal(payload, &request); err != nil {
		return nil, fmt.Errorf("failed to unmarshal ToFunction: %w", err)
	}

	reply, err := h.invoke(ctx, &request)
	if err != nil {
		return nil, err
	}
	return proto.Marshal(reply)
}
// invoke serves a single ToFunction invocation batch.
//
// It looks up the function registered for the batch target's function type.
// If the batch supplies every persisted value the function declared, the
// function is invoked once per invocation in the batch, in order. All
// invocations share one storage built for the batch. The collected side
// effects and state mutations are returned as an InvocationResult. If some
// declared values are missing, no function is invoked; instead an
// IncompleteInvocationContext listing the missing value specs is returned.
//
// An error is returned in these cases:
//   - the payload carries no invocation batch or target,
//   - the function type is unknown,
//   - ctx is done before an invocation starts,
//   - the function returns an error,
//   - the function panics. The panic is recovered and converted into an
//     error instead of terminating the process.
func (h *handler) invoke(ctx context.Context, toFunction *protocol.ToFunction) (from *protocol.FromFunction, err error) {
	batch := toFunction.GetInvocation()
	if batch == nil || batch.GetTarget() == nil {
		// An empty or malformed payload decodes into a ToFunction without an
		// invocation batch; reject it instead of dereferencing nil below.
		return nil, fmt.Errorf("invalid ToFunction: missing invocation batch or target address")
	}

	// Convert panics raised while serving the batch (for example from user
	// function code) into errors rather than taking down the whole process.
	// Error values are wrapped so callers can still inspect them with
	// errors.Is / errors.As.
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		from = nil
		if cause, ok := r.(error); ok {
			err = fmt.Errorf("failed to execute invocation for %s: %w", batch.Target, cause)
		} else {
			err = fmt.Errorf("failed to execute invocation for %s: %v", batch.Target, r)
		}
	}()

	self := addressFromInternal(batch.Target)
	function, exists := h.module[self.FunctionType]
	if !exists {
		return nil, fmt.Errorf("unknown function type %s", self.FunctionType)
	}

	storageFactory := newStorageFactory(batch, h.stateSpecs[self.FunctionType])
	if missing := storageFactory.getMissingSpecs(); missing != nil {
		log.Printf("missing state specs for function type %v", self.FunctionType)
		for _, spec := range missing {
			log.Printf("registering missing specs %v", spec)
		}
		return &protocol.FromFunction{
			Response: &protocol.FromFunction_IncompleteInvocationContext_{
				IncompleteInvocationContext: &protocol.FromFunction_IncompleteInvocationContext{
					MissingValues: missing,
				},
			},
		}, nil
	}

	storage := storageFactory.getStorage()
	response := &protocol.FromFunction_InvocationResponse{}
	for _, invocation := range batch.Invocations {
		// Stop between invocations once the caller has given up on the batch.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		callErr := func() error {
			// Each invocation gets its own cancellable context; the deferred
			// cancel releases it even if the function panics.
			invocationCtx, cancel := context.WithCancel(ctx)
			defer cancel()

			var caller Address
			if invocation.Caller != nil {
				caller = addressFromInternal(invocation.Caller)
			}

			sContext := statefunContext{
				Context:  invocationCtx,
				Mutex:    new(sync.Mutex),
				self:     self,
				caller:   &caller,
				storage:  storage,
				response: response,
			}
			return function.Invoke(&sContext, Message{
				target:     batch.Target,
				typedValue: invocation.Argument,
			})
		}()
		if callErr != nil {
			return nil, fmt.Errorf("failed to execute invocation for %s: %w", batch.Target, callErr)
		}
	}

	response.StateMutations = storage.getStateMutations()
	return &protocol.FromFunction{
		Response: &protocol.FromFunction_InvocationResult{
			InvocationResult: response,
		},
	}, nil
}