blob: d3b6ed47760eaffca279d69ff1adc2d786e38fd9 [file] [log] [blame]
// Copyright 2021-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package triple_protocol_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
import (
triple "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/assert"
pingv1 "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/gen/proto/connect/ping/v1"
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol/internal/gen/proto/connect/ping/v1/pingv1connect"
)
func TestOnionOrderingEndToEnd(t *testing.T) {
t.Parallel()
// Helper function: returns a function that asserts that there's some value
// set for header "expect", and adds a value for header "add".
newInspector := func(expect, add string) func(triple.Spec, http.Header) {
return func(spec triple.Spec, header http.Header) {
if expect != "" {
assert.NotZero(
t,
header.Get(expect),
assert.Sprintf(
"%s (IsClient %v): header %q missing: %v",
spec.Procedure,
spec.IsClient,
expect,
header,
),
)
}
header.Set(add, "v")
}
}
// Helper function: asserts that there's a value present for header keys
// "one", "two", "three", and "four".
assertAllPresent := func(spec triple.Spec, header http.Header) {
for _, key := range []string{"one", "two", "three", "four"} {
assert.NotZero(
t,
header.Get(key),
assert.Sprintf(
"%s (IsClient %v): checking all headers, %q missing: %v",
spec.Procedure,
spec.IsClient,
key,
header,
),
)
}
}
// The client and handler interceptor onions are the meat of the test. The
// order of interceptor execution must be the same for unary and streaming
// procedures.
//
// Requests should fall through the client onion from top to bottom, traverse
// the network, and then fall through the handler onion from top to bottom.
// Responses should climb up the handler onion, traverse the network, and
// then climb up the client onion.
//
// The request and response sides of this onion are numbered to make the
// intended order clear.
clientOnion := triple.WithInterceptors(
newHeaderInterceptor(
// 1 (start). request: should see protocol-related headers
func(_ triple.Spec, h http.Header) {
assert.NotZero(t, h.Get("Content-Type"))
},
// 12 (end). response: check "one"-"four"
assertAllPresent,
),
newHeaderInterceptor(
newInspector("", "one"), // 2. request: add header "one"
newInspector("three", "four"), // 11. response: check "three", add "four"
),
newHeaderInterceptor(
newInspector("one", "two"), // 3. request: check "one", add "two"
newInspector("two", "three"), // 10. response: check "two", add "three"
),
)
handlerOnion := triple.WithInterceptors(
newHeaderInterceptor(
newInspector("two", "three"), // 4. request: check "two", add "three"
newInspector("one", "two"), // 9. response: check "one", add "two"
),
newHeaderInterceptor(
newInspector("three", "four"), // 5. request: check "three", add "four"
newInspector("", "one"), // 8. response: add "one"
),
newHeaderInterceptor(
assertAllPresent, // 6. request: check "one"-"four"
nil, // 7. response: no-op
),
)
mux := http.NewServeMux()
mux.Handle(
pingv1connect.NewPingServiceHandler(
pingServer{},
handlerOnion,
),
)
server := httptest.NewServer(mux)
defer server.Close()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
clientOnion,
)
err := client.Ping(context.Background(), triple.NewRequest(&pingv1.PingRequest{Number: 10}), triple.NewResponse(&pingv1.PingResponse{}))
assert.Nil(t, err)
//responses, err := client.CountUp(context.Background(), triple.NewRequest(&pingv1.CountUpRequest{Number: 10}))
//assert.Nil(t, err)
//var sum int64
//for responses.Receive(&pingv1.CountUpResponse{}) {
// msg := responses.Msg().(pingv1.CountUpResponse)
// sum += msg.Number
//}
//assert.Equal(t, sum, 55)
//assert.Nil(t, responses.Close())
}
//func TestEmptyUnaryInterceptorFunc(t *testing.T) {
// t.Parallel()
// mux := http.NewServeMux()
// interceptor := triple.UnaryInterceptorFunc(func(next triple.UnaryFunc) triple.UnaryFunc {
// return func(ctx context.Context, request triple.AnyRequest, response triple.AnyResponse) error {
// return next(ctx, request, response)
// }
// })
// mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, triple.WithInterceptors(interceptor)))
// server := httptest.NewServer(mux)
// t.Cleanup(server.Close)
// connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, triple.WithInterceptors(interceptor))
// err := connectClient.Ping(context.Background(), triple.NewRequest(&pingv1.PingRequest{}), triple.NewResponse(&pingv1.PingResponse{}))
// assert.Nil(t, err)
// sumStream, err := connectClient.Sum(context.Background())
// assert.Nil(t, err)
// assert.Nil(t, sumStream.Send(&pingv1.SumRequest{Number: 1}))
// err = sumStream.CloseAndReceive(triple.NewResponse(&pingv1.SumResponse{}))
// assert.Nil(t, err)
// countUpStream, err := connectClient.CountUp(context.Background(), triple.NewRequest(&pingv1.CountUpRequest{}))
// assert.Nil(t, err)
// for countUpStream.Receive(&pingv1.CountUpResponse{}) {
// assert.NotNil(t, countUpStream.Msg())
// }
// assert.Nil(t, countUpStream.Close())
//}
// headerInterceptor makes it easier to write interceptors that inspect or
// mutate HTTP headers. It applies the same logic to unary and streaming
// procedures, wrapping the send or receive side of the stream as appropriate.
//
// It's useful as a testing harness to make sure that we're chaining
// interceptors in the correct order.
type headerInterceptor struct {
inspectRequestHeader func(triple.Spec, http.Header)
inspectResponseHeader func(triple.Spec, http.Header)
}
// newHeaderInterceptor constructs a headerInterceptor. Nil function pointers
// are treated as no-ops.
func newHeaderInterceptor(
inspectRequestHeader func(triple.Spec, http.Header),
inspectResponseHeader func(triple.Spec, http.Header),
) *headerInterceptor {
interceptor := headerInterceptor{
inspectRequestHeader: inspectRequestHeader,
inspectResponseHeader: inspectResponseHeader,
}
if interceptor.inspectRequestHeader == nil {
interceptor.inspectRequestHeader = func(_ triple.Spec, _ http.Header) {}
}
if interceptor.inspectResponseHeader == nil {
interceptor.inspectResponseHeader = func(_ triple.Spec, _ http.Header) {}
}
return &interceptor
}
func (h *headerInterceptor) WrapUnaryHandler(next triple.UnaryHandlerFunc) triple.UnaryHandlerFunc {
return func(ctx context.Context, request triple.AnyRequest) (triple.AnyResponse, error) {
h.inspectRequestHeader(request.Spec(), request.Header())
response, err := next(ctx, request)
if err != nil {
return nil, err
}
h.inspectResponseHeader(request.Spec(), response.Header())
return response, nil
}
}
func (h *headerInterceptor) WrapUnary(next triple.UnaryFunc) triple.UnaryFunc {
return func(ctx context.Context, req triple.AnyRequest, res triple.AnyResponse) error {
h.inspectRequestHeader(req.Spec(), req.Header())
err := next(ctx, req, res)
if err != nil {
return err
}
h.inspectResponseHeader(req.Spec(), res.Header())
return nil
}
}
func (h *headerInterceptor) WrapStreamingClient(next triple.StreamingClientFunc) triple.StreamingClientFunc {
return func(ctx context.Context, spec triple.Spec) triple.StreamingClientConn {
return &headerInspectingClientConn{
StreamingClientConn: next(ctx, spec),
inspectRequestHeader: h.inspectRequestHeader,
inspectResponseHeader: h.inspectResponseHeader,
}
}
}
func (h *headerInterceptor) WrapStreamingHandler(next triple.StreamingHandlerFunc) triple.StreamingHandlerFunc {
return func(ctx context.Context, conn triple.StreamingHandlerConn) error {
h.inspectRequestHeader(conn.Spec(), conn.RequestHeader())
return next(ctx, &headerInspectingHandlerConn{
StreamingHandlerConn: conn,
inspectResponseHeader: h.inspectResponseHeader,
})
}
}
type headerInspectingHandlerConn struct {
triple.StreamingHandlerConn
inspectedResponse bool
inspectResponseHeader func(triple.Spec, http.Header)
}
func (hc *headerInspectingHandlerConn) Send(msg interface{}) error {
if !hc.inspectedResponse {
hc.inspectResponseHeader(hc.Spec(), hc.ResponseHeader())
hc.inspectedResponse = true
}
return hc.StreamingHandlerConn.Send(msg)
}
type headerInspectingClientConn struct {
triple.StreamingClientConn
inspectedRequest bool
inspectRequestHeader func(triple.Spec, http.Header)
inspectedResponse bool
inspectResponseHeader func(triple.Spec, http.Header)
}
func (cc *headerInspectingClientConn) Send(msg interface{}) error {
if !cc.inspectedRequest {
cc.inspectRequestHeader(cc.Spec(), cc.RequestHeader())
cc.inspectedRequest = true
}
return cc.StreamingClientConn.Send(msg)
}
func (cc *headerInspectingClientConn) Receive(msg interface{}) error {
err := cc.StreamingClientConn.Receive(msg)
if !cc.inspectedResponse {
cc.inspectResponseHeader(cc.Spec(), cc.ResponseHeader())
cc.inspectedResponse = true
}
return err
}