blob: f175979fc1666853811935f397ce595be22403c0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bufstudioagent
import (
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
)
import (
"github.com/bufbuild/connect-go"
"go.uber.org/zap"
"golang.org/x/net/http2"
"google.golang.org/protobuf/proto"
)
import (
studiov1alpha1 "github.com/apache/dubbo-kubernetes/pkg/bufman/gen/proto/go/studio/v1alpha1"
"github.com/apache/dubbo-kubernetes/pkg/bufman/pkg/protoencoding"
)
// MaxMessageSizeBytesDefault is the default upper bound, in bytes, on how much
// of a request body the handler will read (5 MiB).
const MaxMessageSizeBytesDefault = 5 << 20
// plainPostHandler implements a POST handler for forwarding requests that can
// be called with simple CORS requests.
//
// Simple CORS requests are limited [1] to certain headers and content types, so
// this handler expects base64 encoded protobuf messages in the body and writes
// out base64 encoded protobuf messages to be able to use Content-Type: text/plain.
//
// Because of the content-type restriction we do not define a protobuf service
// that gets served by connect but instead use a plain post handler.
//
// [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#simple_requests
type plainPostHandler struct {
// Logger receives warnings about unexpected upstream errors and errors
// about failed response writes.
Logger *zap.Logger
// MaxMessageSizeBytes caps how many bytes are read from the incoming
// request body. The cap is applied to the raw body, i.e. before base64
// decoding.
MaxMessageSizeBytes int64
// B64Encoding is the base64 variant used both to decode the request body
// and to encode the response body.
B64Encoding *base64.Encoding
// TLSClient is the HTTP client used for targets with the "https" scheme.
TLSClient *http.Client
// H2CClient is the HTTP client used for targets with the "http" scheme.
H2CClient *http.Client
// DisallowedHeaders is the set of header names that an invoke request may
// not carry; lookups use the canonical MIME form of the header key.
DisallowedHeaders map[string]struct{}
// ForwardHeaders maps a header name on the incoming HTTP request to the
// header name it is copied to on the outgoing request, replacing any value
// the invoke request set for that outgoing header.
ForwardHeaders map[string]string
}
// newPlainPostHandler returns a plainPostHandler configured with the standard
// base64 encoding, the default message size limit, and two HTTP/2 clients: one
// that speaks cleartext HTTP/2 and one that uses tlsClientConfig.
//
// The keys of disallowedHeaders and forwardHeaders are stored in canonical MIME
// header form; the values of forwardHeaders are stored as given. The input maps
// are copied, not retained.
func newPlainPostHandler(
	logger *zap.Logger,
	disallowedHeaders map[string]struct{},
	forwardHeaders map[string]string,
	tlsClientConfig *tls.Config,
) *plainPostHandler {
	disallowed := make(map[string]struct{}, len(disallowedHeaders))
	for name := range disallowedHeaders {
		disallowed[textproto.CanonicalMIMEHeaderKey(name)] = struct{}{}
	}

	forward := make(map[string]string, len(forwardHeaders))
	for from, to := range forwardHeaders {
		forward[textproto.CanonicalMIMEHeaderKey(from)] = to
	}

	// For cleartext HTTP/2 the transport's TLS dial hook is replaced with a
	// plain dial; the TLS config handed to the hook is deliberately ignored.
	dialPlain := func(network, address string, _ *tls.Config) (net.Conn, error) {
		return net.Dial(network, address)
	}
	h2cClient := &http.Client{
		Transport: &http2.Transport{
			AllowHTTP: true,
			DialTLS:   dialPlain,
		},
	}
	tlsClient := &http.Client{
		Transport: &http2.Transport{
			TLSClientConfig: tlsClientConfig,
		},
	}

	return &plainPostHandler{
		Logger:              logger,
		MaxMessageSizeBytes: MaxMessageSizeBytesDefault,
		B64Encoding:         base64.StdEncoding,
		TLSClient:           tlsClient,
		H2CClient:           h2cClient,
		DisallowedHeaders:   disallowed,
		ForwardHeaders:      forward,
	}
}
// ServeHTTP reads a base64-encoded studiov1alpha1.InvokeRequest from the
// request body, performs the unary call it describes against the target URL it
// names, and writes the result as a base64-encoded
// studiov1alpha1.InvokeResponse.
//
// Status codes produced by the handler itself:
//   - 405 when the method is not POST.
//   - 415 when the Content-Type header is not exactly "text/plain".
//   - 400 when the body is corrupt or truncated base64, does not unmarshal as
//     an InvokeRequest, carries a disallowed header, names a target that does
//     not parse or is not http/https, or has an unsupported envelope
//     Content-Type.
//   - 413 for any other failure while reading the body (in particular when it
//     exceeds MaxMessageSizeBytes).
//   - 502 when the upstream call fails without usable error information from
//     the server.
//
// When the upstream server answers with a connect error other than
// CodeUnknown, the handler still writes an InvokeResponse whose headers carry
// the error metadata.
func (i *plainPostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "", http.StatusMethodNotAllowed)
		return
	}
	if r.Header.Get("content-type") != "text/plain" {
		http.Error(w, "", http.StatusUnsupportedMediaType)
		return
	}
	// The size limit wraps the raw body, so it bounds the encoded bytes read
	// off the wire rather than the decoded message.
	bodyBytes, err := io.ReadAll(
		base64.NewDecoder(
			i.B64Encoding,
			http.MaxBytesReader(w, r.Body, i.MaxMessageSizeBytes),
		),
	)
	if err != nil {
		// The decoder reports malformed input as a base64.CorruptInputError
		// value (not a pointer), so the errors.As target must be a pointer to
		// the value type; a **CorruptInputError target never matches and would
		// misreport bad input as 413. Input that ends in the middle of a
		// base64 quantum surfaces as io.ErrUnexpectedEOF and is equally a
		// client error.
		var corruptErr base64.CorruptInputError
		if errors.As(err, &corruptErr) || errors.Is(err, io.ErrUnexpectedEOF) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
		return
	}
	envelopeRequest := &studiov1alpha1.InvokeRequest{}
	if err := protoencoding.NewWireUnmarshaler(nil).Unmarshal(bodyBytes, envelopeRequest); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	request := connect.NewRequest(bytes.NewBuffer(envelopeRequest.GetBody()))
	// Copy the envelope headers onto the outgoing request, refusing the whole
	// request as soon as a disallowed header is seen.
	for _, header := range envelopeRequest.Headers {
		if _, ok := i.DisallowedHeaders[textproto.CanonicalMIMEHeaderKey(header.Key)]; ok {
			http.Error(w, fmt.Sprintf("header %q disallowed by agent", header.Key), http.StatusBadRequest)
			return
		}
		for _, value := range header.Value {
			request.Header().Add(header.Key, value)
		}
	}
	// Headers forwarded from the incoming HTTP request take precedence over
	// anything the envelope set for the same outgoing header.
	for fromHeader, toHeader := range i.ForwardHeaders {
		headerValues := r.Header.Values(fromHeader)
		if len(headerValues) > 0 {
			request.Header().Del(toHeader)
			for _, headerValue := range headerValues {
				request.Header().Add(toHeader, headerValue)
			}
		}
	}
	targetURL, err := url.Parse(envelopeRequest.GetTarget())
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	var httpClient *http.Client
	switch targetURL.Scheme {
	case "http":
		httpClient = i.H2CClient
	case "https":
		httpClient = i.TLSClient
	default:
		http.Error(w, fmt.Sprintf("must specify http or https url scheme, got %q", targetURL.Scheme), http.StatusBadRequest)
		return
	}
	// The envelope's own Content-Type header selects the protocol and codec
	// used for the upstream call.
	clientOptions, err := connectClientOptionsFromContentType(request.Header().Get("Content-Type"))
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	client := connect.NewClient[bytes.Buffer, bytes.Buffer](
		httpClient,
		targetURL.String(),
		clientOptions...,
	)
	// TODO(rvanginkel) should this context be cloned to remove attached values (but keep timeout)?
	response, err := client.CallUnary(r.Context(), request)
	if err != nil {
		// We need to differentiate client errors from server errors. In the former,
		// trigger a `StatusBadGateway` result, and in the latter surface whatever
		// error information came back from the server.
		//
		// Any error here is expected to be wrapped in a `connect.Error` struct. We
		// need to check *first* if it's not a wire error, so we can assume the
		// request never left the client, or a response never arrived from the
		// server. In those scenarios we trigger a `StatusBadGateway` to signal
		// that the upstream server is unreachable or in a bad status...
		if !connect.IsWireError(err) {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		// ... but if a response was received from the server, we assume there's
		// error information from the server we can surface to the user by including
		// it in the headers response, unless it is a `CodeUnknown` error. Connect
		// marks any issues connecting with the `CodeUnknown` error.
		if connectErr := new(connect.Error); errors.As(err, &connectErr) {
			if connectErr.Code() == connect.CodeUnknown {
				http.Error(w, err.Error(), http.StatusBadGateway)
				return
			}
			i.writeProtoMessage(w, &studiov1alpha1.InvokeResponse{
				// connectErr.Meta contains the trailers for the
				// caller to find out the error details.
				Headers: goHeadersToProtoHeaders(connectErr.Meta()),
			})
			return
		}
		i.Logger.Warn(
			"non_connect_unary_error",
			zap.Error(err),
		)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	i.writeProtoMessage(w, &studiov1alpha1.InvokeResponse{
		Headers:  goHeadersToProtoHeaders(response.Header()),
		Body:     response.Msg.Bytes(),
		Trailers: goHeadersToProtoHeaders(response.Trailer()),
	})
}
// connectClientOptionsFromContentType translates the Content-Type of an invoke
// request into the connect client options needed to call the target.
//
// The "application/grpc" family selects the gRPC protocol; "application/json"
// and "application/proto" leave the client's protocol at its default. In every
// case a bufferCodec named after the payload format ("proto" or "json") is
// installed. Any other content type is rejected with an error.
func connectClientOptionsFromContentType(contentType string) ([]connect.ClientOption, error) {
	var (
		codecName string
		useGRPC   bool
	)
	switch contentType {
	case "application/grpc", "application/grpc+proto":
		codecName, useGRPC = "proto", true
	case "application/grpc+json":
		codecName, useGRPC = "json", true
	case "application/json":
		codecName = "json"
	case "application/proto":
		codecName = "proto"
	default:
		return nil, fmt.Errorf("unknown Content-Type: %q", contentType)
	}

	// Keep the protocol option ahead of the codec option, as callers receive
	// the options in this order.
	options := make([]connect.ClientOption, 0, 2)
	if useGRPC {
		options = append(options, connect.WithGRPC())
	}
	options = append(options, connect.WithCodec(&bufferCodec{name: codecName}))
	return options, nil
}
// writeProtoMessage marshals message to the protobuf wire format, base64
// encodes it with the handler's encoding, and writes it to w with
// Content-Type: text/plain.
//
// A marshalling failure is reported to the client as a 500. A failed or short
// write cannot be reported to the client any more, so it is only logged.
func (i *plainPostHandler) writeProtoMessage(w http.ResponseWriter, message proto.Message) {
	responseProtoBytes, err := protoencoding.NewWireMarshaler().Marshal(message)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	responseB64Bytes := make([]byte, i.B64Encoding.EncodedLen(len(responseProtoBytes)))
	i.B64Encoding.Encode(responseB64Bytes, responseProtoBytes)
	w.Header().Set("Content-Type", "text/plain")
	// Log on either signal: requiring both a short write AND an error would
	// silently drop an error that is returned alongside a full byte count.
	if n, err := w.Write(responseB64Bytes); err != nil || n != len(responseB64Bytes) {
		i.Logger.Error(
			"write_error",
			zap.Int("expected_bytes", len(responseB64Bytes)),
			zap.Int("actual_bytes", n),
			zap.Error(err),
		)
	}
}
// goHeadersToProtoHeaders converts an http.Header into one
// studiov1alpha1.Headers entry per header key, each carrying that key's value
// slice. The result is nil when in is empty, and the entry order follows map
// iteration and is therefore unspecified.
func goHeadersToProtoHeaders(in http.Header) []*studiov1alpha1.Headers {
	if len(in) == 0 {
		return nil
	}
	converted := make([]*studiov1alpha1.Headers, 0, len(in))
	for key, values := range in {
		entry := &studiov1alpha1.Headers{
			Key:   key,
			Value: values,
		}
		converted = append(converted, entry)
	}
	return converted
}