feat: inject attachments for triple (#2589)
* feat: inject attachments for triple
* inject attachments for server side
* lowercase the attachments keys
* add unit tests
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index 1a403cc..9172796 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -20,7 +20,9 @@
import (
"context"
"fmt"
+ "net/http"
"reflect"
+ "strings"
"sync"
)
@@ -237,8 +239,10 @@
// triple idl mode and old triple idl mode
args = append(args, req.Msg)
}
- // todo: inject method.Meta to attachments
- invo := invocation.NewRPCInvocation(m.Name, args, nil)
+ attachments := generateAttachments(req.Header())
+ // inject attachments
+ ctx = context.WithValue(ctx, constant.AttachmentKey, attachments)
+ invo := invocation.NewRPCInvocation(m.Name, args, attachments)
res := invoker.Invoke(ctx, invo)
// todo(DMwangnima): modify InfoInvoker to get a unified processing logic
// please refer to server/InfoInvoker.Invoke()
@@ -257,7 +261,10 @@
func(ctx context.Context, stream *tri.ClientStream) (*tri.Response, error) {
var args []interface{}
args = append(args, m.StreamInitFunc(stream))
- invo := invocation.NewRPCInvocation(m.Name, args, nil)
+ attachments := generateAttachments(stream.RequestHeader())
+ // inject attachments
+ ctx = context.WithValue(ctx, constant.AttachmentKey, attachments)
+ invo := invocation.NewRPCInvocation(m.Name, args, attachments)
res := invoker.Invoke(ctx, invo)
return res.Result().(*tri.Response), res.Error()
},
@@ -267,10 +274,13 @@
_ = s.triServer.RegisterServerStreamHandler(
procedure,
m.ReqInitFunc,
- func(ctx context.Context, request *tri.Request, stream *tri.ServerStream) error {
+ func(ctx context.Context, req *tri.Request, stream *tri.ServerStream) error {
var args []interface{}
- args = append(args, request.Msg, m.StreamInitFunc(stream))
- invo := invocation.NewRPCInvocation(m.Name, args, nil)
+ args = append(args, req.Msg, m.StreamInitFunc(stream))
+ attachments := generateAttachments(req.Header())
+ // inject attachments
+ ctx = context.WithValue(ctx, constant.AttachmentKey, attachments)
+ invo := invocation.NewRPCInvocation(m.Name, args, attachments)
res := invoker.Invoke(ctx, invo)
return res.Error()
},
@@ -282,7 +292,10 @@
func(ctx context.Context, stream *tri.BidiStream) error {
var args []interface{}
args = append(args, m.StreamInitFunc(stream))
- invo := invocation.NewRPCInvocation(m.Name, args, nil)
+ attachments := generateAttachments(stream.RequestHeader())
+ // inject attachments
+ ctx = context.WithValue(ctx, constant.AttachmentKey, attachments)
+ invo := invocation.NewRPCInvocation(m.Name, args, attachments)
res := invoker.Invoke(ctx, invo)
return res.Error()
},
@@ -409,3 +422,14 @@
return &info
}
+
+// generateAttachments transfer http.Header to map[string]interface{} and make all keys lowercase
+func generateAttachments(header http.Header) map[string]interface{} {
+ attachments := make(map[string]interface{}, len(header))
+ for key, val := range header {
+ lowerKey := strings.ToLower(key)
+ attachments[lowerKey] = val
+ }
+
+ return attachments
+}
diff --git a/protocol/triple/server_test.go b/protocol/triple/server_test.go
new file mode 100644
index 0000000..ac47ad4
--- /dev/null
+++ b/protocol/triple/server_test.go
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_generateAttachments(t *testing.T) {
+ tests := []struct {
+ desc string
+ input func() http.Header
+ expect func(t *testing.T, res map[string]interface{})
+ }{
+ {
+ desc: "empty header",
+ input: func() http.Header {
+ return http.Header{}
+ },
+ expect: func(t *testing.T, res map[string]interface{}) {
+ assert.Zero(t, len(res))
+ },
+ },
+ {
+ desc: "normal header with lowercase keys",
+ input: func() http.Header {
+ header := make(http.Header)
+ header.Set("key1", "val1")
+ header.Set("key2", "val2_1")
+ header.Add("key2", "val2_2")
+ return header
+ },
+ expect: func(t *testing.T, res map[string]interface{}) {
+ assert.Equal(t, 2, len(res))
+ assert.Equal(t, []string{"val1"}, res["key1"])
+ assert.Equal(t, []string{"val2_1", "val2_2"}, res["key2"])
+ },
+ },
+ {
+ desc: "normal header with uppercase keys",
+ input: func() http.Header {
+ header := make(http.Header)
+ header.Set("Key1", "val1")
+ header.Set("Key2", "val2_1")
+ header.Add("Key2", "val2_2")
+ return header
+ },
+ expect: func(t *testing.T, res map[string]interface{}) {
+ assert.Equal(t, 2, len(res))
+ assert.Equal(t, []string{"val1"}, res["key1"])
+ assert.Equal(t, []string{"val2_1", "val2_2"}, res["key2"])
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ atta := generateAttachments(test.input())
+ test.expect(t, atta)
+ })
+ }
+}
diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go
index 2e83765..e087784 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -19,8 +19,11 @@
import (
"context"
+ "errors"
"fmt"
"sync"
+
+ tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
)
import (
@@ -33,6 +36,10 @@
"dubbo.apache.org/dubbo-go/v3/protocol"
)
+var triAttachmentKeys = []string{
+ constant.InterfaceKey, constant.TokenKey, constant.TimeoutKey,
+}
+
type TripleInvoker struct {
protocol.BaseInvoker
quitOnce sync.Once
@@ -73,22 +80,14 @@
result.SetError(protocol.ErrClientClosed)
return &result
}
- callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
- if !ok {
- panic("Miss CallType to invoke TripleInvoker")
+
+ ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
+ if err != nil {
+ result.SetError(err)
+ return &result
}
- callType, ok := callTypeRaw.(string)
- if !ok {
- panic(fmt.Sprintf("CallType should be string, but got %v", callTypeRaw))
- }
- // please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen
- // e.g. Client.CallUnary(... req, resp []interface, ...)
- // inRaw represents req and resp, inRawLen represents 2.
- inRaw := invocation.ParameterRawValues()
- invocation.Reply()
inRawLen := len(inRaw)
- method := invocation.MethodName()
- // todo(DMwangnima): process headers(metadata) passed in
+
if !ti.clientManager.isIDL {
switch callType {
case constant.CallUnary:
@@ -137,6 +136,69 @@
return &result
}
+// parseInvocation retrieves information from invocation.
+// it returns ctx, callType, inRaw, method, error
+func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) {
+ callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
+ if !ok {
+ return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
+ }
+ callType, ok := callTypeRaw.(string)
+ if !ok {
+ return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
+ }
+ // please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen
+ // e.g. Client.CallUnary(... req, resp []interface, ...)
+ // inRaw represents req and resp
+ inRaw := invocation.ParameterRawValues()
+ method := invocation.MethodName()
+ if method == "" {
+ return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
+ }
+
+ ctx, err := parseAttachments(ctx, url, invocation)
+ if err != nil {
+ return nil, "", nil, "", err
+ }
+
+ return ctx, callType, inRaw, method, nil
+}
+
+// parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx
+func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) {
+ // retrieve users passed-in attachment
+ attaRaw := ctx.Value(constant.AttachmentKey)
+ if attaRaw != nil {
+ if userAtta, ok := attaRaw.(map[string]interface{}); ok {
+ for key, val := range userAtta {
+ invocation.SetAttachment(key, val)
+ }
+ }
+ }
+ // set pre-defined attachments
+ for _, key := range triAttachmentKeys {
+ if val := url.GetParam(key, ""); len(val) > 0 {
+ invocation.SetAttachment(key, val)
+ }
+ }
+ // inject attachments
+ for key, valRaw := range invocation.Attachments() {
+ if str, ok := valRaw.(string); ok {
+ ctx = tri.AppendToOutgoingContext(ctx, key, str)
+ continue
+ }
+ if strs, ok := valRaw.([]string); ok {
+ for _, str := range strs {
+ ctx = tri.AppendToOutgoingContext(ctx, key, str)
+ }
+ continue
+ }
+ return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
+ }
+
+ return ctx, nil
+}
+
// IsAvailable get available status
func (ti *TripleInvoker) IsAvailable() bool {
if ti.getClientManager() != nil {
diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go
new file mode 100644
index 0000000..7d14d42
--- /dev/null
+++ b/protocol/triple/triple_invoker_test.go
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple
+
+import (
+ "context"
+ "testing"
+
+ "dubbo.apache.org/dubbo-go/v3/common"
+ "dubbo.apache.org/dubbo-go/v3/common/constant"
+ "dubbo.apache.org/dubbo-go/v3/protocol"
+ "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+ tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_parseInvocation(t *testing.T) {
+ tests := []struct {
+ desc string
+ ctx func() context.Context
+ url *common.URL
+ invo func() protocol.Invocation
+ expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error)
+ }{
+ {
+ desc: "miss callType",
+ ctx: func() context.Context {
+ return context.Background()
+ },
+ url: common.NewURLWithOptions(),
+ invo: func() protocol.Invocation {
+ return invocation.NewRPCInvocationWithOptions()
+ },
+ expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ assert.NotNil(t, err)
+ },
+ },
+ {
+ desc: "wrong callType",
+ ctx: func() context.Context {
+ return context.Background()
+ },
+ url: common.NewURLWithOptions(),
+ invo: func() protocol.Invocation {
+ iv := invocation.NewRPCInvocationWithOptions()
+ iv.SetAttribute(constant.CallTypeKey, 1)
+ return iv
+ },
+ expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ assert.NotNil(t, err)
+ },
+ },
+ {
+ desc: "empty methodName",
+ ctx: func() context.Context {
+ return context.Background()
+ },
+ url: common.NewURLWithOptions(),
+ invo: func() protocol.Invocation {
+ iv := invocation.NewRPCInvocationWithOptions()
+ iv.SetAttribute(constant.CallTypeKey, constant.CallUnary)
+ return iv
+ },
+ expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ assert.NotNil(t, err)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
+ test.expect(t, ctx, callType, inRaw, methodName, err)
+ })
+ }
+}
+
+func Test_parseAttachments(t *testing.T) {
+ tests := []struct {
+ desc string
+ ctx func() context.Context
+ url *common.URL
+ invo func() protocol.Invocation
+ expect func(t *testing.T, ctx context.Context, err error)
+ }{
+ {
+ desc: "url has pre-defined keys in triAttachmentKeys",
+ ctx: func() context.Context {
+ return context.Background()
+ },
+ url: common.NewURLWithOptions(
+ common.WithInterface("interface"),
+ common.WithToken("token"),
+ ),
+ invo: func() protocol.Invocation {
+ return invocation.NewRPCInvocationWithOptions()
+ },
+ expect: func(t *testing.T, ctx context.Context, err error) {
+ assert.Nil(t, err)
+ header := tri.ExtractFromOutgoingContext(ctx)
+ assert.NotNil(t, header)
+ assert.Equal(t, "interface", header.Get(constant.InterfaceKey))
+ assert.Equal(t, "token", header.Get(constant.TokenKey))
+ },
+ },
+ {
+ desc: "user passed-in legal attachments",
+ ctx: func() context.Context {
+ userDefined := make(map[string]interface{})
+ userDefined["key1"] = "val1"
+ userDefined["key2"] = []string{"key2_1", "key2_2"}
+ return context.WithValue(context.Background(), constant.AttachmentKey, userDefined)
+ },
+ url: common.NewURLWithOptions(),
+ invo: func() protocol.Invocation {
+ return invocation.NewRPCInvocationWithOptions()
+ },
+ expect: func(t *testing.T, ctx context.Context, err error) {
+ assert.Nil(t, err)
+ header := tri.ExtractFromOutgoingContext(ctx)
+ assert.NotNil(t, header)
+ assert.Equal(t, "val1", header.Get("key1"))
+ assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2"))
+ },
+ },
+ {
+ desc: "user passed-in illegal attachments",
+ ctx: func() context.Context {
+ userDefined := make(map[string]interface{})
+ userDefined["key1"] = 1
+ return context.WithValue(context.Background(), constant.AttachmentKey, userDefined)
+ },
+ url: common.NewURLWithOptions(),
+ invo: func() protocol.Invocation {
+ return invocation.NewRPCInvocationWithOptions()
+ },
+ expect: func(t *testing.T, ctx context.Context, err error) {
+ assert.NotNil(t, err)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ ctx, err := parseAttachments(test.ctx(), test.url, test.invo())
+ test.expect(t, ctx, err)
+ })
+ }
+}
diff --git a/protocol/triple/triple_protocol/client.go b/protocol/triple/triple_protocol/client.go
index 4d49a6f..8ba67ee 100644
--- a/protocol/triple/triple_protocol/client.go
+++ b/protocol/triple/triple_protocol/client.go
@@ -124,6 +124,7 @@
if flag {
defer cancel()
}
+ mergeHeaders(request.Header(), ExtractFromOutgoingContext(ctx))
applyGroupVersionHeaders(request.Header(), c.config)
return c.callUnary(ctx, request, response)
}
@@ -170,6 +171,7 @@
func (c *Client) newConn(ctx context.Context, streamType StreamType) StreamingClientConn {
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
+ mergeHeaders(header, ExtractFromOutgoingContext(ctx))
applyGroupVersionHeaders(header, c.config)
c.protocolClient.WriteRequestHeader(streamType, header)
return c.protocolClient.NewConn(ctx, spec, header)
diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go
index 261a6ad..28618e0 100644
--- a/protocol/triple/triple_protocol/header.go
+++ b/protocol/triple/triple_protocol/header.go
@@ -126,6 +126,15 @@
return context.WithValue(ctx, headerOutgoingKey{}, header)
}
+func ExtractFromOutgoingContext(ctx context.Context) http.Header {
+ headerRaw := ctx.Value(headerOutgoingKey{})
+ if headerRaw == nil {
+ return nil
+ }
+ // since headerOutgoingKey is only used in triple_protocol package, we need not verify the type
+ return headerRaw.(http.Header)
+}
+
// FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext.
// Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1.
func FromIncomingContext(ctx context.Context) (http.Header, bool) {
diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go
index dd72625..b2d43ea 100644
--- a/protocol/triple/triple_test.go
+++ b/protocol/triple/triple_test.go
@@ -87,7 +87,6 @@
func (t *tripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
name := invocation.MethodName()
args := invocation.Arguments()
- // todo(DMwangnima): use map to represent Methods
for _, method := range t.info.Methods {
if method.Name == name {
res, err := method.MethodFunc(ctx, args, t.handler)
@@ -124,7 +123,6 @@
func runOldTripleServer(interfaceName string, group string, version string, addr string, desc *grpc_go.ServiceDesc, svc common.RPCService) {
url := common.NewURLWithOptions(
- // todo(DMwangnima): figure this out
common.WithPath(interfaceName),
common.WithLocation(addr),
common.WithPort(dubbo3Port),