blob: 3c02ee62af777b1e52e446ebd1036cd191e9c8a7 [file] [log] [blame]
// Copyright 2021-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package triple_protocol
import (
"context"
"fmt"
"net/http"
)
const (
// defaultImplementationsSize is the initial capacity hint passed to make
// when a Handler's implementations map is allocated.
defaultImplementationsSize = 4
)
// A Handler is the server-side implementation of a single RPC defined by a
// service schema.
//
// By default, Handlers support the Triple, gRPC, and gRPC-Web protocols with
// the binary Protobuf and JSON codecs. They support gzip compression using the
// standard library's [compress/gzip].
//
// A Handler can hold more than one implementation of its procedure, each
// stored under a "group/version" identifier. ServeHTTP picks the
// implementation from the service group and version request headers.
type Handler struct {
// spec describes the procedure this Handler serves; ServeHTTP reads its
// StreamType to reject bidirectional streams over HTTP/1.x.
spec Spec
// key is group/version
implementations map[string]StreamingHandlerFunc
// protocolHandlers are the wire-protocol handlers ServeHTTP tries, in order,
// against the request method and content type.
protocolHandlers []protocolHandler
allowMethod string // Allow header
acceptPost string // Accept-Post header
}
// NewUnaryHandler builds the [Handler] that serves a request-response (unary)
// procedure.
//
// reqInitFunc is invoked once per call to obtain the value the incoming
// message is received into, and unary is the business function that turns the
// request into a response. The resulting implementation is registered under
// the identifier derived from the configured Group and Version.
func NewUnaryHandler(
	procedure string,
	reqInitFunc func() interface{},
	unary func(context.Context, *Request) (*Response, error),
	options ...HandlerOption,
) *Handler {
	cfg := newHandlerConfig(procedure, options)
	impl := generateUnaryHandlerFunc(procedure, reqInitFunc, unary, cfg.Interceptor)
	handlers := cfg.newProtocolHandlers(StreamTypeUnary)

	handler := new(Handler)
	handler.spec = cfg.newSpec(StreamTypeUnary)
	handler.implementations = make(map[string]StreamingHandlerFunc, defaultImplementationsSize)
	handler.protocolHandlers = handlers
	handler.allowMethod = sortedAllowMethodValue(handlers)
	handler.acceptPost = sortedAcceptPostValue(handlers)

	key := getIdentifier(cfg.Group, cfg.Version)
	handler.processImplementation(key, impl)
	return handler
}
// generateUnaryHandlerFunc adapts a strongly-typed unary function to the
// stream-oriented StreamingHandlerFunc signature.
//
// The returned function receives exactly one message from the connection
// (into the value produced by reqInitFunc), invokes unary through the optional
// interceptor, merges the response headers and trailers into the connection
// and sends the response message.
//
// procedure is only used to make diagnostics more helpful. If unary returns a
// nil response together with a nil error the function panics, because such a
// result cannot be serialized.
func generateUnaryHandlerFunc(
	procedure string,
	reqInitFunc func() interface{},
	unary func(context.Context, *Request) (*Response, error),
	interceptor Interceptor,
) StreamingHandlerFunc {
	// Wrap the strongly-typed implementation so we can apply interceptors.
	untyped := UnaryHandlerFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
		// Don't run user logic for a call that is already cancelled or expired.
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		typed, ok := request.(*Request)
		if !ok {
			return nil, errorf(CodeInternal, "unexpected handler request type %T", request)
		}
		res, err := unary(ctx, typed)
		if res == nil {
			if err == nil {
				// This is going to panic during serialization. Debugging is much easier
				// if we panic here instead, so we can include the procedure name.
				panic(fmt.Sprintf("%s returned nil *triple.Response and nil error", procedure)) //nolint: forbidigo
			}
			// Return an untyped nil: wrapping a nil *Response in AnyResponse would
			// yield a non-nil interface, so interceptors checking `resp != nil`
			// would go on to dereference a nil pointer.
			return nil, err
		}
		return res, err
	})
	// todo: modify server func
	if interceptor != nil {
		untyped = interceptor.WrapUnaryHandler(untyped)
	}
	// Given a stream, how should we call the unary function? Receive one
	// message, call the function, send one message. The conn is responsible
	// for marshaling and unmarshaling.
	implementation := func(ctx context.Context, conn StreamingHandlerConn) error {
		req := reqInitFunc()
		if err := conn.Receive(req); err != nil {
			return err
		}
		// Wrap the decoded message together with the call metadata.
		request := NewRequest(req)
		request.spec = conn.Spec()
		request.peer = conn.Peer()
		request.header = conn.RequestHeader()
		// embed header in context so that user logic could process them via FromIncomingContext
		ctx = newIncomingContext(ctx, conn.RequestHeader())
		response, err := untyped(ctx, request)
		if err != nil {
			return err
		}
		if response == nil {
			// An interceptor swallowed the response without reporting an error.
			// Report it instead of panicking on the nil interface below.
			return errorf(CodeInternal, "%s returned nil response and nil error", procedure)
		}
		// merge headers
		mergeHeaders(conn.ResponseHeader(), response.Header())
		mergeHeaders(conn.ResponseTrailer(), response.Trailer())
		return conn.Send(response.Any())
	}
	return implementation
}
// NewClientStreamHandler builds the [Handler] that serves a client streaming
// procedure: streamFunc reads from the client stream and returns a single
// response.
func NewClientStreamHandler(
	procedure string,
	streamFunc func(context.Context, *ClientStream) (*Response, error),
	options ...HandlerOption,
) *Handler {
	cfg := newHandlerConfig(procedure, options)
	impl := generateClientStreamHandlerFunc(procedure, streamFunc, cfg.Interceptor)
	handlers := cfg.newProtocolHandlers(StreamTypeClient)

	handler := new(Handler)
	handler.spec = cfg.newSpec(StreamTypeClient)
	handler.implementations = make(map[string]StreamingHandlerFunc, defaultImplementationsSize)
	handler.protocolHandlers = handlers
	handler.allowMethod = sortedAllowMethodValue(handlers)
	handler.acceptPost = sortedAcceptPostValue(handlers)

	key := getIdentifier(cfg.Group, cfg.Version)
	handler.processImplementation(key, impl)
	return handler
}
// generateClientStreamHandlerFunc adapts a client streaming function to the
// StreamingHandlerFunc signature and applies the interceptor, if one is given.
//
// The adapter hands the connection to streamFunc wrapped in a ClientStream,
// then merges the returned headers and trailers into the connection and sends
// the response message. A nil response with a nil error causes a panic that
// names the procedure.
func generateClientStreamHandlerFunc(
	procedure string,
	streamFunc func(context.Context, *ClientStream) (*Response, error),
	interceptor Interceptor,
) StreamingHandlerFunc {
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		clientStream := &ClientStream{conn: conn}
		// Expose the request headers to user code through FromIncomingContext.
		ctx = newIncomingContext(ctx, conn.RequestHeader())
		response, err := streamFunc(ctx, clientStream)
		switch {
		case err != nil:
			return err
		case response == nil:
			// Serialization would panic anyway; panicking here lets us name the
			// offending procedure.
			panic(fmt.Sprintf("%s returned nil *triple.Response and nil error", procedure)) //nolint: forbidigo
		}
		mergeHeaders(conn.ResponseHeader(), response.header)
		mergeHeaders(conn.ResponseTrailer(), response.trailer)
		return conn.Send(response.Msg)
	}
	if interceptor == nil {
		return serve
	}
	return interceptor.WrapStreamingHandler(serve)
}
// NewServerStreamHandler builds the [Handler] that serves a server streaming
// procedure: a single request, received into the value returned by
// reqInitFunc, is passed to streamFunc together with the server stream.
func NewServerStreamHandler(
	procedure string,
	reqInitFunc func() interface{},
	streamFunc func(context.Context, *Request, *ServerStream) error,
	options ...HandlerOption,
) *Handler {
	cfg := newHandlerConfig(procedure, options)
	impl := generateServerStreamHandlerFunc(procedure, reqInitFunc, streamFunc, cfg.Interceptor)
	handlers := cfg.newProtocolHandlers(StreamTypeServer)

	handler := new(Handler)
	handler.spec = cfg.newSpec(StreamTypeServer)
	handler.implementations = make(map[string]StreamingHandlerFunc, defaultImplementationsSize)
	handler.protocolHandlers = handlers
	handler.allowMethod = sortedAllowMethodValue(handlers)
	handler.acceptPost = sortedAcceptPostValue(handlers)

	key := getIdentifier(cfg.Group, cfg.Version)
	handler.processImplementation(key, impl)
	return handler
}
// generateServerStreamHandlerFunc adapts a server streaming function to the
// StreamingHandlerFunc signature and applies the interceptor, if one is given.
//
// The adapter receives one message into the value produced by reqInitFunc and
// then calls streamFunc with that request and a ServerStream bound to the
// connection. The procedure argument is currently not used.
func generateServerStreamHandlerFunc(
	procedure string,
	reqInitFunc func() interface{},
	streamFunc func(context.Context, *Request, *ServerStream) error,
	interceptor Interceptor,
) StreamingHandlerFunc {
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		msg := reqInitFunc()
		if recvErr := conn.Receive(msg); recvErr != nil {
			return recvErr
		}
		// Expose the request headers to user code through FromIncomingContext.
		ctx = newIncomingContext(ctx, conn.RequestHeader())
		request := &Request{
			Msg:    msg,
			spec:   conn.Spec(),
			peer:   conn.Peer(),
			header: conn.RequestHeader(),
		}
		serverStream := &ServerStream{conn: conn}
		return streamFunc(ctx, request, serverStream)
	}
	if interceptor == nil {
		return serve
	}
	return interceptor.WrapStreamingHandler(serve)
}
// NewBidiStreamHandler builds the [Handler] that serves a bidirectional
// streaming procedure implemented by streamFunc.
func NewBidiStreamHandler(
	procedure string,
	streamFunc func(context.Context, *BidiStream) error,
	options ...HandlerOption,
) *Handler {
	cfg := newHandlerConfig(procedure, options)
	impl := generateBidiStreamHandlerFunc(procedure, streamFunc, cfg.Interceptor)
	handlers := cfg.newProtocolHandlers(StreamTypeBidi)

	handler := new(Handler)
	handler.spec = cfg.newSpec(StreamTypeBidi)
	handler.implementations = make(map[string]StreamingHandlerFunc, defaultImplementationsSize)
	handler.protocolHandlers = handlers
	handler.allowMethod = sortedAllowMethodValue(handlers)
	handler.acceptPost = sortedAcceptPostValue(handlers)

	key := getIdentifier(cfg.Group, cfg.Version)
	handler.processImplementation(key, impl)
	return handler
}
// generateBidiStreamHandlerFunc adapts a bidirectional streaming function to
// the StreamingHandlerFunc signature and applies the interceptor, if one is
// given. The procedure argument is currently not used.
func generateBidiStreamHandlerFunc(
	procedure string,
	streamFunc func(context.Context, *BidiStream) error,
	interceptor Interceptor,
) StreamingHandlerFunc {
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		// Expose the request headers to user code through FromIncomingContext.
		incoming := newIncomingContext(ctx, conn.RequestHeader())
		bidi := &BidiStream{conn: conn}
		return streamFunc(incoming, bidi)
	}
	if interceptor == nil {
		return serve
	}
	return interceptor.WrapStreamingHandler(serve)
}
// processImplementation stores implementation under identifier, replacing
// whatever was registered under the same identifier before.
func (h *Handler) processImplementation(identifier string, implementation StreamingHandlerFunc) {
	registry := h.implementations
	registry[identifier] = implementation
}
// ServeHTTP implements [http.Handler].
//
// The request is processed in the following steps:
//  1. Bidirectional procedures are refused with 505 when the request is not
//     at least HTTP/2.
//  2. Protocol handlers are filtered by HTTP method; if none accepts the
//     method the response is 405 with an Allow header.
//  3. The first remaining protocol handler that can handle the content type
//     is selected; if there is none the response is 415 with an Accept-Post
//     header.
//  4. A connection is created and the implementation registered for the
//     service group and version found in the request headers is invoked. If
//     no implementation is registered for that pair, the connection is closed
//     with an unimplemented error instead of calling a nil function.
func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
	// We don't need to defer functions to close the request body or read to
	// EOF: the stream we construct later on already does that, and we only
	// return early when dealing with misbehaving clients. In those cases, it's
	// okay if we can't re-use the connection.
	isBidi := (h.spec.StreamType & StreamTypeBidi) == StreamTypeBidi
	if isBidi && request.ProtoMajor < 2 {
		// Clients coded to expect full-duplex connections may hang if they've
		// mistakenly negotiated HTTP/1.1. To unblock them, we must close the
		// underlying TCP connection.
		responseWriter.Header().Set("Connection", "close")
		responseWriter.WriteHeader(http.StatusHTTPVersionNotSupported)
		return
	}

	// Keep only the protocol handlers that accept the request's HTTP method.
	var protocolHandlers []protocolHandler
	for _, handler := range h.protocolHandlers {
		if _, ok := handler.Methods()[request.Method]; ok {
			protocolHandlers = append(protocolHandlers, handler)
		}
	}
	if len(protocolHandlers) == 0 {
		responseWriter.Header().Set("Allow", h.allowMethod)
		responseWriter.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	contentType := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))

	// Find our implementation of the RPC protocol in use.
	var protocolHdl protocolHandler
	for _, handler := range protocolHandlers {
		if handler.CanHandlePayload(request, contentType) {
			protocolHdl = handler
			break
		}
	}
	if protocolHdl == nil {
		responseWriter.Header().Set("Accept-Post", h.acceptPost)
		responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
		return
	}

	// Establish a stream and serve the RPC.
	setHeaderCanonical(request.Header, headerContentType, contentType)

	// Derive the call context. When the timeout cannot be applied we fall back
	// to the request context so that the error can still be sent over a
	// properly constructed connection below.
	ctx, cancel, timeoutErr := protocolHdl.SetTimeout(request) //nolint: contextcheck
	if timeoutErr != nil {
		ctx = request.Context()
	}
	if cancel != nil {
		defer cancel()
	}

	// create stream
	connCloser, ok := protocolHdl.NewConn(
		responseWriter,
		request.WithContext(ctx),
	)
	if !ok {
		// Failed to create stream, usually because client used an unknown
		// compression algorithm. Nothing further to do.
		return
	}
	if timeoutErr != nil {
		// The error passed to Close is what gets reported for the call; a
		// failure of Close itself cannot be reported anywhere from here.
		_ = connCloser.Close(timeoutErr)
		return
	}

	// Pick the implementation matching the requested service group/version.
	svcGroup := request.Header.Get(tripleServiceGroup)
	svcVersion := request.Header.Get(tripleServiceVersion)
	implementation, ok := h.implementations[getIdentifier(svcGroup, svcVersion)]
	if !ok || implementation == nil {
		// Calling a nil StreamingHandlerFunc would panic, so report the missing
		// implementation to the client instead.
		_ = connCloser.Close(errorf(
			CodeUnimplemented,
			"%s: no implementation registered for group %q and version %q",
			h.spec.Procedure, svcGroup, svcVersion,
		))
		return
	}
	_ = connCloser.Close(implementation(ctx, connCloser))
}
// handlerConfig collects the settings used to build a Handler. It is created
// with defaults by newHandlerConfig and then mutated by every HandlerOption
// through applyToHandler.
type handlerConfig struct {
// CompressionPools and CompressionNames are handed to
// newReadOnlyCompressionPools when the protocol handlers are built.
CompressionPools map[string]*compressionPool
CompressionNames []string
// Codecs is handed to newReadOnlyCodecs when the protocol handlers are built.
Codecs map[string]Codec
CompressMinBytes int
// Interceptor, when non-nil, wraps the handler implementation.
Interceptor Interceptor
// Procedure holds the result of extractProtoPath applied to the procedure
// name given to newHandlerConfig.
Procedure string
// HandleGRPC controls whether the gRPC protocol handler is registered;
// newHandlerConfig sets it to true before options are applied.
HandleGRPC bool
RequireTripleProtocolHeader bool
IdempotencyLevel IdempotencyLevel
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
// Group and Version are joined by getIdentifier into the key under which
// the implementation is registered on the Handler.
Group string
Version string
}
// newHandlerConfig returns the configuration for the given procedure.
//
// It starts from the defaults (gRPC handling enabled, a new buffer pool, the
// binary Protobuf and JSON codecs, gzip compression) and then applies options
// in order, so that later options override earlier ones and the defaults.
func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig {
	cfg := &handlerConfig{
		Procedure:        extractProtoPath(procedure),
		CompressionPools: make(map[string]*compressionPool),
		Codecs:           make(map[string]Codec),
		HandleGRPC:       true,
		BufferPool:       newBufferPool(),
	}

	// Built-in defaults come first so that user options can replace them.
	withProtoBinaryCodec().applyToHandler(cfg)
	withProtoJSONCodecs().applyToHandler(cfg)
	withGzip().applyToHandler(cfg)

	for i := range options {
		options[i].applyToHandler(cfg)
	}
	return cfg
}
// newSpec returns the Spec for the configured procedure with the given stream
// type and the configured idempotency level.
func (c *handlerConfig) newSpec(streamType StreamType) Spec {
	var spec Spec
	spec.Procedure = c.Procedure
	spec.StreamType = streamType
	spec.IdempotencyLevel = c.IdempotencyLevel
	return spec
}
// newProtocolHandlers builds one protocolHandler per enabled protocol.
//
// The Triple protocol is enabled for unary procedures only; the gRPC protocol
// is enabled whenever HandleGRPC is set. All handlers share the same read-only
// codec and compression pool views but get their own parameter struct.
func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHandler {
	// Decide which protocols take part.
	var enabled []protocol
	if streamType == StreamTypeUnary {
		enabled = append(enabled, &protocolTriple{})
	}
	if c.HandleGRPC {
		enabled = append(enabled, &protocolGRPC{})
	}

	result := make([]protocolHandler, 0, len(enabled))

	// Codecs and compressors are shared by every protocol handler.
	readOnlyCodecs := newReadOnlyCodecs(c.Codecs)
	readOnlyPools := newReadOnlyCompressionPools(
		c.CompressionPools,
		c.CompressionNames,
	)

	for i := range enabled {
		params := &protocolHandlerParams{
			Spec:             c.newSpec(streamType),
			Codecs:           readOnlyCodecs,
			CompressionPools: readOnlyPools,
			// Plain values copied from the config.
			CompressMinBytes:            c.CompressMinBytes,
			BufferPool:                  c.BufferPool,
			ReadMaxBytes:                c.ReadMaxBytes,
			SendMaxBytes:                c.SendMaxBytes,
			RequireTripleProtocolHeader: c.RequireTripleProtocolHeader,
			IdempotencyLevel:            c.IdempotencyLevel,
		}
		result = append(result, enabled[i].NewHandler(params))
	}
	return result
}
// getIdentifier returns the key under which the implementation for the given
// service group and version is stored, in the form "<group>/<version>".
// Empty inputs are kept as they are, so two empty strings yield "/".
func getIdentifier(group, version string) string {
	const separator = "/"
	return group + separator + version
}
// newStreamHandler builds a Handler of the given stream type around an
// already stream-shaped implementation. The implementation is wrapped by the
// configured interceptor, if any, and registered under the identifier derived
// from the configured Group and Version.
func newStreamHandler(
	procedure string,
	streamType StreamType,
	implementation StreamingHandlerFunc,
	options ...HandlerOption,
) *Handler {
	cfg := newHandlerConfig(procedure, options)
	if cfg.Interceptor != nil {
		implementation = cfg.Interceptor.WrapStreamingHandler(implementation)
	}
	handlers := cfg.newProtocolHandlers(streamType)

	handler := new(Handler)
	handler.spec = cfg.newSpec(streamType)
	handler.implementations = make(map[string]StreamingHandlerFunc, defaultImplementationsSize)
	handler.protocolHandlers = handlers
	handler.allowMethod = sortedAllowMethodValue(handlers)
	handler.acceptPost = sortedAcceptPostValue(handlers)

	key := getIdentifier(cfg.Group, cfg.Version)
	handler.processImplementation(key, implementation)
	return handler
}