blob: cc539cc3f7b3df97502b8059e4980ba90f31c0a9 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bufstudioagent
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
)
import (
"github.com/bufbuild/connect-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/protobuf/proto"
)
import (
studiov1alpha1 "github.com/apache/dubbo-kubernetes/pkg/bufman/gen/proto/go/studio/v1alpha1"
"github.com/apache/dubbo-kubernetes/pkg/bufman/pkg/protoencoding"
)
// Connect procedure paths registered by newTestConnectServer.
const (
	// errorPath always fails with CodeFailedPrecondition.
	errorPath = "/error.Service/Error"
	// echoPath echoes the request body and headers back to the caller.
	echoPath  = "/echo.Service/EchoEcho"
)
// TestPlainPostHandlerTLS runs the plain-POST agent handler suites against
// an upstream Connect server served over TLS with HTTP/2 enabled.
func TestPlainPostHandlerTLS(t *testing.T) {
	server := newTestConnectServer(t, true)
	t.Cleanup(server.Close)
	testPlainPostHandler(t, server)
	testPlainPostHandlerErrors(t, server)
}
// TestPlainPostHandlerH2C runs the same plain-POST agent handler suites
// against an upstream Connect server speaking cleartext HTTP/2 (h2c).
func TestPlainPostHandlerH2C(t *testing.T) {
	server := newTestConnectServer(t, false)
	t.Cleanup(server.Close)
	testPlainPostHandler(t, server)
	testPlainPostHandlerErrors(t, server)
}
// testPlainPostHandler covers the successful relay path of the agent handler.
//
// Each case base64-encodes an InvokeRequest into a text/plain POST to the
// agent, which forwards it to upstreamServer's echo endpoint. The agent's
// base64-encoded InvokeResponse is then decoded, and its body, headers and
// trailers are checked. The agent receives upstreamServer.TLS (nil for a
// plain-HTTP upstream) as its upstream TLS configuration.
func testPlainPostHandler(t *testing.T, upstreamServer *httptest.Server) {
	agentServer := httptest.NewTLSServer(
		NewHandler(
			zaptest.NewLogger(t),
			"https://example.buf.build",
			upstreamServer.TLS,
			nil,
			// The Echo-Bar assertion below relies on this mapping forwarding
			// the agent request's Foo header upstream as Bar.
			map[string]string{"foo": "bar"},
			false,
		),
	)
	defer agentServer.Close()
	testCases := []struct {
		name        string
		contentType string
		// wantGRPCStatus is the expected grpc-status value; "" means the
		// upstream response must carry no (or an empty) grpc-status.
		wantGRPCStatus string
	}{
		{
			name:           "content_type_grpc_proto",
			contentType:    "application/grpc+proto",
			wantGRPCStatus: "0",
		},
		{
			name:           "content_type_application_proto",
			contentType:    "application/proto",
			wantGRPCStatus: "",
		},
	}
	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			requestProto := &studiov1alpha1.InvokeRequest{
				Target: upstreamServer.URL + echoPath,
				Headers: goHeadersToProtoHeaders(http.Header{
					"Content-Type": []string{testCase.contentType},
				}),
				Body: []byte("echothis"),
			}
			requestBytes := protoMarshalBase64(t, requestProto)
			request, err := http.NewRequest(http.MethodPost, agentServer.URL, bytes.NewReader(requestBytes))
			require.NoError(t, err)
			request.Header.Set("Content-Type", "text/plain")
			request.Header.Set("Origin", "https://example.buf.build")
			request.Header.Set("Foo", "foo-value")
			response, err := agentServer.Client().Do(request)
			require.NoError(t, err)
			defer response.Body.Close()
			assert.Equal(t, http.StatusOK, response.StatusCode)
			assert.Equal(t, "https://example.buf.build", response.Header.Get("Access-Control-Allow-Origin"))
			responseBytes, err := io.ReadAll(response.Body)
			// A failed read leaves nothing meaningful to decode, so stop here
			// rather than piling up follow-on failures.
			require.NoError(t, err)
			invokeResponse := &studiov1alpha1.InvokeResponse{}
			protoUnmarshalBase64(t, responseBytes, invokeResponse)
			// Merge headers and trailers so grpc-status is found wherever the
			// upstream placed it.
			upstreamResponseHeaders := make(http.Header)
			addProtoHeadersToGoHeader(invokeResponse.Headers, upstreamResponseHeaders)
			addProtoHeadersToGoHeader(invokeResponse.Trailers, upstreamResponseHeaders)
			assert.Equal(t, testCase.wantGRPCStatus, upstreamResponseHeaders.Get("grpc-status"))
			assert.Equal(t, []byte("echo: echothis"), invokeResponse.Body)
			assert.Equal(t, "foo-value", upstreamResponseHeaders.Get("Echo-Bar"))
		})
	}
}
// testPlainPostHandlerErrors covers the agent handler's failure paths:
//   - a request header the agent is configured to forbid (expects 400),
//   - an upstream that answers with a Connect error, which is relayed back
//     inside a 200 InvokeResponse,
//   - an upstream that drops the connection (expects 502).
func testPlainPostHandlerErrors(t *testing.T, upstreamServer *httptest.Server) {
	agentServer := httptest.NewTLSServer(
		NewHandler(
			zaptest.NewLogger(t),
			"https://example.buf.build",
			upstreamServer.TLS,
			map[string]struct{}{"forbidden-header": {}},
			nil,
			false,
		),
	)
	defer agentServer.Close()
	t.Run("forbidden_header", func(t *testing.T) {
		requestProto := &studiov1alpha1.InvokeRequest{
			Target: upstreamServer.URL + echoPath,
			Headers: goHeadersToProtoHeaders(http.Header{
				"forbidden-header": []string{"<tokens>"},
			}),
		}
		requestBytes := protoMarshalBase64(t, requestProto)
		request, err := http.NewRequest(http.MethodPost, agentServer.URL, bytes.NewReader(requestBytes))
		require.NoError(t, err)
		request.Header.Set("Content-Type", "text/plain")
		response, err := agentServer.Client().Do(request)
		require.NoError(t, err)
		defer response.Body.Close()
		assert.Equal(t, http.StatusBadRequest, response.StatusCode)
	})
	t.Run("error_response", func(t *testing.T) {
		requestProto := &studiov1alpha1.InvokeRequest{
			Target: upstreamServer.URL + errorPath,
			Headers: goHeadersToProtoHeaders(http.Header{
				"Content-Type": []string{"application/grpc"},
			}),
			Body: []byte("something"),
		}
		requestBytes := protoMarshalBase64(t, requestProto)
		request, err := http.NewRequest(http.MethodPost, agentServer.URL, bytes.NewReader(requestBytes))
		require.NoError(t, err)
		request.Header.Set("Content-Type", "text/plain")
		response, err := agentServer.Client().Do(request)
		require.NoError(t, err)
		defer response.Body.Close()
		assert.Equal(t, http.StatusOK, response.StatusCode)
		responseBytes, err := io.ReadAll(response.Body)
		// Nothing useful can be decoded from a failed read; stop here.
		require.NoError(t, err)
		invokeResponse := &studiov1alpha1.InvokeResponse{}
		protoUnmarshalBase64(t, responseBytes, invokeResponse)
		upstreamResponseHeaders := make(http.Header)
		addProtoHeadersToGoHeader(invokeResponse.Headers, upstreamResponseHeaders)
		addProtoHeadersToGoHeader(invokeResponse.Trailers, upstreamResponseHeaders)
		assert.Equal(t, strconv.Itoa(int(connect.CodeFailedPrecondition)), upstreamResponseHeaders.Get("grpc-status"))
		assert.Equal(t, "something", upstreamResponseHeaders.Get("grpc-message"))
	})
	t.Run("invalid_upstream", func(t *testing.T) {
		listener, err := net.Listen("tcp", "127.0.0.1:")
		require.NoError(t, err)
		// The fake upstream accepts one connection and closes it immediately,
		// so the agent's forwarded request fails.
		//
		// The goroutine deliberately does not report through t: FailNow
		// (used by require) is only valid on the goroutine running the test.
		// Also, Accept legitimately returns an error once the listener is
		// closed, which may happen before any client connects.
		acceptDone := make(chan struct{})
		go func() {
			defer close(acceptDone)
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			// Best effort: the only purpose is to drop the connection.
			_ = conn.Close()
		}()
		defer func() {
			// Closing the listener unblocks a still-pending Accept, so the
			// goroutine is guaranteed to finish before the subtest returns.
			_ = listener.Close()
			<-acceptDone
		}()
		requestProto := &studiov1alpha1.InvokeRequest{
			Target: "http://" + listener.Addr().String(),
			Headers: goHeadersToProtoHeaders(http.Header{
				"Content-Type": []string{"application/grpc"},
			}),
		}
		requestBytes := protoMarshalBase64(t, requestProto)
		request, err := http.NewRequest(http.MethodPost, agentServer.URL, bytes.NewReader(requestBytes))
		require.NoError(t, err)
		request.Header.Set("Content-Type", "text/plain")
		response, err := agentServer.Client().Do(request)
		require.NoError(t, err)
		defer response.Body.Close()
		assert.Equal(t, http.StatusBadGateway, response.StatusCode)
	})
}
func newTestConnectServer(t *testing.T, tls bool) *httptest.Server {
mux := http.NewServeMux()
// echoPath echoes all incoming headers (prefixed with "Echo-") and the
// body bytes prefixed with "echo: "
mux.Handle(echoPath, connect.NewUnaryHandler(
echoPath,
func(ctx context.Context, r *connect.Request[bytes.Buffer]) (*connect.Response[bytes.Buffer], error) {
response := connect.NewResponse(bytes.NewBuffer(append([]byte("echo: "), r.Msg.Bytes()...)))
for header, values := range r.Header() {
for _, value := range values {
response.Header().Add("Echo-"+header, value)
}
}
return response, nil
},
connect.WithCodec(&bufferCodec{name: "proto"}),
))
// errorPath returns the body as error message with code failed precondition
mux.Handle(errorPath, connect.NewUnaryHandler(
errorPath,
func(ctx context.Context, r *connect.Request[bytes.Buffer]) (*connect.Response[bytes.Buffer], error) {
return nil, connect.NewError(connect.CodeFailedPrecondition, errors.New(r.Msg.String()))
},
connect.WithCodec(&bufferCodec{name: "proto"}),
))
if tls {
upstreamServerTLS := httptest.NewUnstartedServer(mux)
upstreamServerTLS.EnableHTTP2 = true
upstreamServerTLS.StartTLS()
certpool := x509.NewCertPool()
certpool.AddCert(upstreamServerTLS.Certificate())
upstreamServerTLS.TLS.RootCAs = certpool
return upstreamServerTLS
}
return httptest.NewServer(h2c.NewHandler(mux, &http2.Server{}))
}
// protoMarshalBase64 serializes message with protoencoding's wire marshaler.
// It returns the bytes encoded with standard, padded base64, which is the
// form these tests POST to the agent. A marshal error fails the test
// immediately.
func protoMarshalBase64(t *testing.T, message proto.Message) []byte {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()
	protoBytes, err := protoencoding.NewWireMarshaler().Marshal(message)
	require.NoError(t, err)
	base64Bytes := make([]byte, base64.StdEncoding.EncodedLen(len(protoBytes)))
	base64.StdEncoding.Encode(base64Bytes, protoBytes)
	return base64Bytes
}
// protoUnmarshalBase64 reverses protoMarshalBase64. It decodes base64Bytes as
// standard, padded base64, then unmarshals the result into message with
// protoencoding's wire unmarshaler. Any decode or unmarshal error fails the
// test immediately.
func protoUnmarshalBase64(t *testing.T, base64Bytes []byte, message proto.Message) {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()
	protoBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes)))
	// DecodedLen is only an upper bound; keep just the bytes actually written.
	actualLen, err := base64.StdEncoding.Decode(protoBytes, base64Bytes)
	require.NoError(t, err)
	protoBytes = protoBytes[:actualLen]
	require.NoError(t, protoencoding.NewWireUnmarshaler(nil).Unmarshal(protoBytes, message))
}
// addProtoHeadersToGoHeader appends every value of every entry in fromHeaders
// to toHeaders under that entry's Key. It uses http.Header.Add, so existing
// values are kept and keys are canonicalized.
func addProtoHeadersToGoHeader(fromHeaders []*studiov1alpha1.Headers, toHeaders http.Header) {
	for i := range fromHeaders {
		key, values := fromHeaders[i].Key, fromHeaders[i].Value
		for j := range values {
			toHeaders.Add(key, values[j])
		}
	}
}