fix: resolve triple bugs based on dubbo-go-samples (#2545)
diff --git a/go.mod b/go.mod
index 27af2cc..630870e 100644
--- a/go.mod
+++ b/go.mod
@@ -45,7 +45,6 @@
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mitchellh/mapstructure v1.5.0
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd
- github.com/nacos-group/nacos-sdk-go v1.0.9 // indirect
github.com/nacos-group/nacos-sdk-go/v2 v2.2.2
github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852
github.com/opentracing/opentracing-go v1.2.0
@@ -57,6 +56,7 @@
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.8.3
+ github.com/ugorji/go/codec v1.2.6
go.etcd.io/etcd/api/v3 v3.5.7
go.etcd.io/etcd/client/v2 v2.305.0 // indirect
go.etcd.io/etcd/client/v3 v3.5.7
diff --git a/go.sum b/go.sum
index a181a39..6e9903f 100644
--- a/go.sum
+++ b/go.sum
@@ -1004,9 +1004,8 @@
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
+github.com/nacos-group/nacos-sdk-go v1.0.8 h1:8pEm05Cdav9sQgJSv5kyvlgfz0SzFUUGI3pWX6SiSnM=
github.com/nacos-group/nacos-sdk-go v1.0.8/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA=
-github.com/nacos-group/nacos-sdk-go v1.0.9 h1:sMvrp6tZj4LdhuHRsS4GCqASB81k3pjmT2ykDQQpwt0=
-github.com/nacos-group/nacos-sdk-go v1.0.9/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA=
github.com/nacos-group/nacos-sdk-go/v2 v2.1.2/go.mod h1:ys/1adWeKXXzbNWfRNbaFlX/t6HVLWdpsNDvmoWTw0g=
github.com/nacos-group/nacos-sdk-go/v2 v2.2.2 h1:FI+7vr1fvCA4jbgx36KezmP3zlU/WoP/7wAloaSd1Ew=
github.com/nacos-group/nacos-sdk-go/v2 v2.2.2/go.mod h1:ys/1adWeKXXzbNWfRNbaFlX/t6HVLWdpsNDvmoWTw0g=
diff --git a/protocol/triple/client.go b/protocol/triple/client.go
index 18b339e..efbc8be 100644
--- a/protocol/triple/client.go
+++ b/protocol/triple/client.go
@@ -47,6 +47,7 @@
// callUnary, callClientStream, callServerStream, callBidiStream.
// A Reference has a clientManager.
type clientManager struct {
+ isIDL bool
// triple_protocol clients, key is method name
triClients map[string]*tri.Client
}
@@ -115,10 +116,8 @@
return nil
}
+// newClientManager extracts configurations from url and builds clientManager
func newClientManager(url *common.URL) (*clientManager, error) {
- // If global trace instance was set, it means trace function enabled.
- // If not, will return NoopTracer.
- // tracer := opentracing.GlobalTracer()
var cliOpts []tri.ClientOption
// set max send and recv msg size
@@ -133,12 +132,18 @@
}
cliOpts = append(cliOpts, tri.WithSendMaxBytes(maxCallSendMsgSize))
+ var isIDL bool
// set serialization
serialization := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
switch serialization {
case constant.ProtobufSerialization:
+ isIDL = true
case constant.JSONSerialization:
cliOpts = append(cliOpts, tri.WithProtoJSON())
+ case constant.Hessian2Serialization:
+ cliOpts = append(cliOpts, tri.WithHessian2())
+ case constant.MsgpackSerialization:
+ cliOpts = append(cliOpts, tri.WithMsgPack())
default:
panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
}
@@ -147,42 +152,16 @@
timeout := url.GetParamDuration(constant.TimeoutKey, "")
cliOpts = append(cliOpts, tri.WithTimeout(timeout))
+ // set service group and version
group := url.GetParam(constant.GroupKey, "")
version := url.GetParam(constant.VersionKey, "")
cliOpts = append(cliOpts, tri.WithGroup(group), tri.WithVersion(version))
- // dialOpts = append(dialOpts,
- //
- // grpc.WithBlock(),
- // // todo config tracing
- // grpc.WithTimeout(time.Second*3),
- // grpc.WithUnaryInterceptor(otgrpc.OpenTracingClientInterceptor(tracer, otgrpc.LogPayloads())),
- // grpc.WithStreamInterceptor(otgrpc.OpenTracingStreamClientInterceptor(tracer, otgrpc.LogPayloads())),
- // grpc.WithDefaultCallOptions(
- // grpc.CallContentSubtype(clientConf.ContentSubType),
- // grpc.MaxCallRecvMsgSize(maxCallRecvMsgSize),
- // grpc.MaxCallSendMsgSize(maxCallSendMsgSize),
- // ),
- //
- // )
+ // todo(DMwangnima): support opentracing
+
+ // todo(DMwangnima): support TLS in an ideal way
var cfg *tls.Config
var tlsFlag bool
- //var err error
-
- // todo: think about a more elegant way to configure tls
- //if tlsConfig := config.GetRootConfig().TLSConfig; tlsConfig != nil {
- // cfg, err = config.GetClientTlsConfig(&config.TLSConfig{
- // CACertFile: tlsConfig.CACertFile,
- // TLSCertFile: tlsConfig.TLSCertFile,
- // TLSKeyFile: tlsConfig.TLSKeyFile,
- // TLSServerName: tlsConfig.TLSServerName,
- // })
- // if err != nil {
- // return nil, err
- // }
- // logger.Infof("TRIPLE clientManager initialized the TLSConfig configuration successfully")
- // tlsFlag = true
- //}
var transport http.RoundTripper
callType := url.GetParam(constant.CallHTTPTypeKey, constant.CallHTTP2)
@@ -231,6 +210,7 @@
}
return &clientManager{
+ isIDL: isIDL,
triClients: triClients,
}, nil
}
diff --git a/protocol/triple/codec.go b/protocol/triple/codec.go
deleted file mode 100644
index 3c6e851..0000000
--- a/protocol/triple/codec.go
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package triple
-
-import (
- "bytes"
- "encoding/json"
-)
-
-import (
- "github.com/golang/protobuf/jsonpb"
- "github.com/golang/protobuf/proto"
-
- grpcEncoding "google.golang.org/grpc/encoding"
-)
-
-const (
- codecJson = "json"
- codecProto = "proto"
-)
-
-func init() {
- grpcEncoding.RegisterCodec(grpcJson{
- Marshaler: jsonpb.Marshaler{
- EmitDefaults: true,
- OrigName: true,
- },
- })
-}
-
-type grpcJson struct {
- jsonpb.Marshaler
- jsonpb.Unmarshaler
-}
-
-// Name implements grpc encoding package Codec interface method,
-// returns the name of the Codec implementation.
-func (_ grpcJson) Name() string {
- return codecJson
-}
-
-// Marshal implements grpc encoding package Codec interface method,returns the wire format of v.
-func (j grpcJson) Marshal(v interface{}) (out []byte, err error) {
- if pm, ok := v.(proto.Message); ok {
- b := new(bytes.Buffer)
- err := j.Marshaler.Marshal(b, pm)
- if err != nil {
- return nil, err
- }
- return b.Bytes(), nil
- }
- return json.Marshal(v)
-}
-
-// Unmarshal implements grpc encoding package Codec interface method,Unmarshal parses the wire format into v.
-func (j grpcJson) Unmarshal(data []byte, v interface{}) (err error) {
- if pm, ok := v.(proto.Message); ok {
- b := bytes.NewBuffer(data)
- return j.Unmarshaler.Unmarshal(b, pm)
- }
- return json.Unmarshal(data, v)
-}
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index fd6c6b9..7a2dc7d 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -20,8 +20,8 @@
import (
"context"
"fmt"
+ "reflect"
"sync"
- "time"
)
import (
@@ -44,14 +44,16 @@
"dubbo.apache.org/dubbo-go/v3/server"
)
-// Server is TRIPLE server
+// Server is TRIPLE adaptation layer representation. It makes use of tri.Server to
+// provide functionality.
type Server struct {
triServer *tri.Server
- services map[string]grpc.ServiceInfo
mu sync.RWMutex
+ services map[string]grpc.ServiceInfo
}
-// NewServer creates a new TRIPLE server
+// NewServer creates a new TRIPLE server.
+// triServer would not be initialized since we could not get configurations here.
func NewServer() *Server {
return &Server{
services: make(map[string]grpc.ServiceInfo),
@@ -60,55 +62,35 @@
// Start TRIPLE server
func (s *Server) Start(invoker protocol.Invoker, info *server.ServiceInfo) {
- var (
- addr string
- URL *common.URL
- hanOpts []tri.HandlerOption
- )
- URL = invoker.GetURL()
- addr = URL.Location
+ URL := invoker.GetURL()
+ addr := URL.Location
+ // initialize tri.Server
s.triServer = tri.NewServer(addr)
+
serialization := URL.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
switch serialization {
case constant.ProtobufSerialization:
case constant.JSONSerialization:
+ case constant.Hessian2Serialization:
+ case constant.MsgpackSerialization:
default:
panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
}
- // todo: implement interceptor
- // If global trace instance was set, then server tracer instance
- // can be get. If not, will return NoopTracer.
- //tracer := opentracing.GlobalTracer()
- //serverOpts = append(serverOpts,
- // grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(tracer)),
- // grpc.StreamInterceptor(otgrpc.OpenTracingStreamServerInterceptor(tracer)),
- // grpc.MaxRecvMsgSize(maxServerRecvMsgSize),
- // grpc.MaxSendMsgSize(maxServerSendMsgSize),
- //)
- //var cfg *tls.Config
+ // todo: support opentracing interceptor
+
// todo(DMwangnima): think about a more elegant way to configure tls
- //tlsConfig := config.GetRootConfig().TLSConfig
- //if tlsConfig != nil {
- // cfg, err = config.GetServerTlsConfig(&config.TLSConfig{
- // CACertFile: tlsConfig.CACertFile,
- // TLSCertFile: tlsConfig.TLSCertFile,
- // TLSKeyFile: tlsConfig.TLSKeyFile,
- // TLSServerName: tlsConfig.TLSServerName,
- // })
- // if err != nil {
- // return
- // }
- // logger.Infof("Triple Server initialized the TLSConfig configuration")
- //}
- //srv.TLSConfig = cfg
+
// todo:// move tls config to handleService
- hanOpts = getHanOpts(URL)
+ hanOpts := getHanOpts(URL)
+ intfName := URL.Interface()
if info != nil {
- s.handleServiceWithInfo(invoker, info, hanOpts...)
- s.saveServiceInfo(info)
+ // new triple idl mode
+ s.handleServiceWithInfo(intfName, invoker, info, hanOpts...)
+ s.saveServiceInfo(intfName, info)
} else {
- s.compatHandleService(URL, hanOpts...)
+ // old triple idl mode and non-idl mode
+ s.compatHandleService(intfName, URL.Group(), URL.Version(), hanOpts...)
}
reflection.Register(s)
@@ -119,26 +101,26 @@
}()
}
+// todo(DMwangnima): extract a common function
// RefreshService refreshes Triple Service
func (s *Server) RefreshService(invoker protocol.Invoker, info *server.ServiceInfo) {
- var (
- URL *common.URL
- hanOpts []tri.HandlerOption
- )
- URL = invoker.GetURL()
+ URL := invoker.GetURL()
serialization := URL.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
switch serialization {
case constant.ProtobufSerialization:
case constant.JSONSerialization:
+ case constant.Hessian2Serialization:
+ case constant.MsgpackSerialization:
default:
panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
}
- hanOpts = getHanOpts(URL)
+ hanOpts := getHanOpts(URL)
+ intfName := URL.Interface()
if info != nil {
- s.handleServiceWithInfo(invoker, info, hanOpts...)
- s.saveServiceInfo(info)
+ s.handleServiceWithInfo(intfName, invoker, info, hanOpts...)
+ s.saveServiceInfo(intfName, info)
} else {
- s.compatHandleService(URL, hanOpts...)
+ s.compatHandleService(intfName, URL.Group(), URL.Version(), hanOpts...)
}
}
@@ -157,7 +139,6 @@
hanOpts = append(hanOpts, tri.WithSendMaxBytes(maxServerSendMsgSize))
// todo:// open tracing
- hanOpts = append(hanOpts, tri.WithInterceptors())
group := url.GetParam(constant.GroupKey, "")
version := url.GetParam(constant.VersionKey, "")
@@ -165,90 +146,58 @@
return hanOpts
}
-// getSyncMapLen gets sync map len
-func getSyncMapLen(m *sync.Map) int {
- length := 0
-
- m.Range(func(_, _ interface{}) bool {
- length++
- return true
- })
- return length
-}
-
-// waitTripleExporter wait until len(providerServices) = len(ExporterMap)
-func waitTripleExporter(providerServices map[string]*config.ServiceConfig) {
- t := time.NewTicker(50 * time.Millisecond)
- defer t.Stop()
- pLen := len(providerServices)
- ta := time.NewTimer(10 * time.Second)
- defer ta.Stop()
-
- for {
- select {
- case <-t.C:
- mLen := getSyncMapLen(tripleProtocol.ExporterMap())
- if pLen == mLen {
- return
- }
- case <-ta.C:
- panic("wait Triple exporter timeout when start GRPC_NEW server")
- }
- }
-}
-
-// *Important*, this function is responsible for being compatible with old triple-gen code
+// *Important*, this function is responsible for being compatible with old triple-gen code and non-idl code
// compatHandleService registers handler based on ServiceConfig and provider service.
-func (s *Server) compatHandleService(url *common.URL, opts ...tri.HandlerOption) {
+func (s *Server) compatHandleService(interfaceName string, group, version string, opts ...tri.HandlerOption) {
providerServices := config.GetProviderConfig().Services
if len(providerServices) == 0 {
- logger.Info("Provider service map is null")
+ logger.Info("Provider service map is null, please register ProviderServices")
+ return
}
- //waitTripleExporter(providerServices)
for key, providerService := range providerServices {
- if providerService.Interface != url.Interface() {
+ if providerService.Interface != interfaceName || providerService.Group != group || providerService.Version != version {
continue
}
// todo(DMwangnima): judge protocol type
service := config.GetProviderService(key)
- ds, ok := service.(dubbo3.Dubbo3GrpcService)
- if !ok {
- panic("illegal service type registered")
- }
-
serviceKey := common.ServiceKey(providerService.Interface, providerService.Group, providerService.Version)
exporter, _ := tripleProtocol.ExporterMap().Load(serviceKey)
if exporter == nil {
- // todo(DMwangnima): handler reflection Service and health Service
+ logger.Warnf("no exporter found for serviceKey: %v", serviceKey)
continue
- //panic(fmt.Sprintf("no exporter found for servicekey: %v", serviceKey))
}
invoker := exporter.(protocol.Exporter).GetInvoker()
if invoker == nil {
panic(fmt.Sprintf("no invoker found for servicekey: %v", serviceKey))
}
+ ds, ok := service.(dubbo3.Dubbo3GrpcService)
+ if !ok {
+ info := createServiceInfoWithReflection(service)
+ s.handleServiceWithInfo(interfaceName, invoker, info, opts...)
+ continue
+ }
// inject invoker, it has all invocation logics
ds.XXX_SetProxyImpl(invoker)
- s.compatRegisterHandler(ds, opts...)
+ s.compatRegisterHandler(interfaceName, ds, opts...)
}
}
-func (s *Server) compatRegisterHandler(svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) {
+func (s *Server) compatRegisterHandler(interfaceName string, svc dubbo3.Dubbo3GrpcService, opts ...tri.HandlerOption) {
desc := svc.XXX_ServiceDesc()
// init unary handlers
for _, method := range desc.Methods {
// please refer to protocol/triple/internal/proto/triple_gen/greettriple for procedure examples
// error could be ignored because base is empty string
- procedure := joinProcedure(desc.ServiceName, method.MethodName)
- _ = s.triServer.RegisterCompatUnaryHandler(procedure, svc, tri.MethodHandler(method.Handler), opts...)
+ procedure := joinProcedure(interfaceName, method.MethodName)
+ _ = s.triServer.RegisterCompatUnaryHandler(procedure, method.MethodName, svc, tri.MethodHandler(method.Handler), opts...)
}
// init stream handlers
for _, stream := range desc.Streams {
// please refer to protocol/triple/internal/proto/triple_gen/greettriple for procedure examples
// error could be ignored because base is empty string
- procedure := joinProcedure(desc.ServiceName, stream.StreamName)
+ procedure := joinProcedure(interfaceName, stream.StreamName)
var typ tri.StreamType
switch {
case stream.ClientStreams && stream.ServerStreams:
@@ -263,10 +212,10 @@
}
// handleServiceWithInfo injects invoker and create handler based on ServiceInfo
-func (s *Server) handleServiceWithInfo(invoker protocol.Invoker, info *server.ServiceInfo, opts ...tri.HandlerOption) {
+func (s *Server) handleServiceWithInfo(interfaceName string, invoker protocol.Invoker, info *server.ServiceInfo, opts ...tri.HandlerOption) {
for _, method := range info.Methods {
m := method
- procedure := joinProcedure(info.InterfaceName, method.Name)
+ procedure := joinProcedure(interfaceName, method.Name)
switch m.Type {
case constant.CallUnary:
_ = s.triServer.RegisterUnaryHandler(
@@ -274,12 +223,28 @@
m.ReqInitFunc,
func(ctx context.Context, req *tri.Request) (*tri.Response, error) {
var args []interface{}
- args = append(args, req.Msg)
+ if argsRaw, ok := req.Msg.([]interface{}); ok {
+ // non-idl mode, req.Msg consists of many arguments
+ for _, argRaw := range argsRaw {
+ // refer to createServiceInfoWithReflection, in ReqInitFunc, argRaw is a pointer to real arg.
+ // so we have to invoke Elem to get the real arg.
+ args = append(args, reflect.ValueOf(argRaw).Elem().Interface())
+ }
+ } else {
+ // triple idl mode and old triple idl mode
+ args = append(args, req.Msg)
+ }
// todo: inject method.Meta to attachments
invo := invocation.NewRPCInvocation(m.Name, args, nil)
res := invoker.Invoke(ctx, invo)
- // todo(DMwangnima): if we do not use MethodInfo.MethodFunc, create Response manually
- return res.Result().(*tri.Response), res.Error()
+ // todo(DMwangnima): modify InfoInvoker to get a unified processing logic
+ // please refer to server/InfoInvoker.Invoke()
+ if triResp, ok := res.Result().(*tri.Response); ok {
+ return triResp, res.Error()
+ }
+ // please refer to proxy/proxy_factory/ProxyInvoker.Invoke
+ triResp := tri.NewResponse([]interface{}{res.Result()})
+ return triResp, res.Error()
},
opts...,
)
@@ -324,7 +289,7 @@
}
}
-func (s *Server) saveServiceInfo(info *server.ServiceInfo) {
+func (s *Server) saveServiceInfo(interfaceName string, info *server.ServiceInfo) {
ret := grpc.ServiceInfo{}
ret.Methods = make([]grpc.MethodInfo, 0, len(info.Methods))
for _, method := range info.Methods {
@@ -349,7 +314,8 @@
ret.Metadata = info
s.mu.Lock()
defer s.mu.Unlock()
- s.services[info.InterfaceName] = ret
+ // todo(DMwangnima): using interfaceName is not enough, we need to consider group and version
+ s.services[interfaceName] = ret
}
func (s *Server) GetServiceInfo() map[string]grpc.ServiceInfo {
@@ -371,3 +337,47 @@
func (s *Server) GracefulStop() {
_ = s.triServer.GracefulStop(context.Background())
}
+
+// createServiceInfoWithReflection is for non-idl scenario.
+// It makes use of reflection to extract method parameters information and create ServiceInfo.
+// As a result, Server could use this ServiceInfo to register.
+func createServiceInfoWithReflection(svc common.RPCService) *server.ServiceInfo {
+ var info server.ServiceInfo
+ val := reflect.ValueOf(svc)
+ typ := reflect.TypeOf(svc)
+ methodNum := val.NumMethod()
+ methodInfos := make([]server.MethodInfo, methodNum)
+ for i := 0; i < methodNum; i++ {
+ methodType := typ.Method(i)
+ if methodType.Name == "Reference" {
+ continue
+ }
+ paramsNum := methodType.Type.NumIn()
+ // the first param is receiver itself, the second param is ctx
+ // just ignore them
+ if paramsNum < 2 {
+ logger.Fatalf("TRIPLE does not support %s method that does not have any parameter", methodType.Name)
+ continue
+ }
+ paramsTypes := make([]reflect.Type, paramsNum-2)
+ for j := 2; j < paramsNum; j++ {
+ paramsTypes[j-2] = methodType.Type.In(j)
+ }
+ methodInfo := server.MethodInfo{
+ Name: methodType.Name,
+ // only support Unary invocation now
+ Type: constant.CallUnary,
+ ReqInitFunc: func() interface{} {
+ params := make([]interface{}, len(paramsTypes))
+ for k, paramType := range paramsTypes {
+ params[k] = reflect.New(paramType).Interface()
+ }
+ return params
+ },
+ }
+ methodInfos[i] = methodInfo
+ }
+ info.Methods = methodInfos
+
+ return &info
+}
diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go
index 6f41520..74bd772 100644
--- a/protocol/triple/triple.go
+++ b/protocol/triple/triple.go
@@ -80,6 +80,7 @@
tp.SetExporterMap(serviceKey, exporter)
logger.Infof("[TRIPLE Protocol] Export service: %s", url.String())
tp.openServer(invoker, info)
+ health.SetServingStatusServing(url.Service())
return exporter
}
@@ -98,8 +99,8 @@
}
srv := NewServer()
- tp.serverMap[url.Location] = srv
srv.Start(invoker, info)
+ tp.serverMap[url.Location] = srv
}
// Refer a remote triple service
diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go
index 525fae6..2e83765 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -40,18 +40,18 @@
clientManager *clientManager
}
-func (gni *TripleInvoker) setClientManager(cm *clientManager) {
- gni.clientGuard.Lock()
- defer gni.clientGuard.Unlock()
+func (ti *TripleInvoker) setClientManager(cm *clientManager) {
+ ti.clientGuard.Lock()
+ defer ti.clientGuard.Unlock()
- gni.clientManager = cm
+ ti.clientManager = cm
}
-func (gni *TripleInvoker) getClientManager() *clientManager {
- gni.clientGuard.RLock()
- defer gni.clientGuard.RUnlock()
+func (ti *TripleInvoker) getClientManager() *clientManager {
+ ti.clientGuard.RLock()
+ defer ti.clientGuard.RUnlock()
- return gni.clientManager
+ return ti.clientManager
}
// Invoke is used to call client-side method.
@@ -89,6 +89,18 @@
inRawLen := len(inRaw)
method := invocation.MethodName()
// todo(DMwangnima): process headers(metadata) passed in
+ if !ti.clientManager.isIDL {
+ switch callType {
+ case constant.CallUnary:
+ // todo(DMwangnima): consider inRawLen == 0
+ if err := ti.clientManager.callUnary(ctx, method, inRaw[0:inRawLen-1], inRaw[inRawLen-1]); err != nil {
+ result.SetError(err)
+ }
+ default:
+ panic("Triple only supports Unary Invocation for Non-IDL mode")
+ }
+ return &result
+ }
switch callType {
case constant.CallUnary:
if len(inRaw) != 2 {
@@ -96,13 +108,11 @@
}
if err := ti.clientManager.callUnary(ctx, method, inRaw[0], inRaw[1]); err != nil {
result.SetError(err)
- return &result
}
case constant.CallClientStream:
stream, err := ti.clientManager.callClientStream(ctx, method)
if err != nil {
result.SetError(err)
- return &result
}
result.SetResult(stream)
case constant.CallServerStream:
@@ -112,14 +122,12 @@
stream, err := ti.clientManager.callServerStream(ctx, method, inRaw[0])
if err != nil {
result.Err = err
- return &result
}
result.SetResult(stream)
case constant.CallBidiStream:
stream, err := ti.clientManager.callBidiStream(ctx, method)
if err != nil {
result.Err = err
- return &result
}
result.SetResult(stream)
default:
@@ -130,18 +138,18 @@
}
// IsAvailable get available status
-func (gni *TripleInvoker) IsAvailable() bool {
- if gni.getClientManager() != nil {
- return gni.BaseInvoker.IsAvailable()
+func (ti *TripleInvoker) IsAvailable() bool {
+ if ti.getClientManager() != nil {
+ return ti.BaseInvoker.IsAvailable()
}
return false
}
// IsDestroyed get destroyed status
-func (gni *TripleInvoker) IsDestroyed() bool {
- if gni.getClientManager() != nil {
- return gni.BaseInvoker.IsDestroyed()
+func (ti *TripleInvoker) IsDestroyed() bool {
+ if ti.getClientManager() != nil {
+ return ti.BaseInvoker.IsDestroyed()
}
return false
diff --git a/protocol/triple/triple_protocol/codec.go b/protocol/triple/triple_protocol/codec.go
index 8e7d499..ff131f6 100644
--- a/protocol/triple/triple_protocol/codec.go
+++ b/protocol/triple/triple_protocol/codec.go
@@ -19,6 +19,10 @@
"encoding/json"
"errors"
"fmt"
+ hessian "github.com/apache/dubbo-go-hessian2"
+ "github.com/dubbogo/grpc-go/encoding"
+ "github.com/dubbogo/grpc-go/encoding/proto_wrapper_api"
+ "github.com/dubbogo/grpc-go/encoding/tools"
)
import (
@@ -27,11 +31,15 @@
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoiface"
+
+ msgpack "github.com/ugorji/go/codec"
)
const (
codecNameProto = "proto"
codecNameJSON = "json"
+ codecNameHessian2 = "hessian2"
+ codecNameMsgPack = "msgpack"
codecNameJSONCharsetUTF8 = codecNameJSON + "; charset=utf-8"
)
@@ -173,6 +181,111 @@
return false
}
+// todo(DMwangnima): add unit tests
+type protoWrapperCodec struct {
+ innerCodec Codec
+}
+
+func (c *protoWrapperCodec) Name() string {
+ return c.innerCodec.Name()
+}
+
+func (c *protoWrapperCodec) Marshal(message interface{}) ([]byte, error) {
+ reqs, ok := message.([]interface{})
+ if !ok {
+ return c.innerCodec.Marshal(message)
+ }
+ reqsLen := len(reqs)
+ reqsBytes := make([][]byte, reqsLen)
+ reqsTypes := make([]string, reqsLen)
+ for i, req := range reqs {
+ reqBytes, err := c.innerCodec.Marshal(req)
+ if err != nil {
+ return nil, err
+ }
+ reqsBytes[i] = reqBytes
+ reqsTypes[i] = encoding.GetArgType(req)
+ }
+
+ wrapperReq := &proto_wrapper_api.TripleRequestWrapper{
+ SerializeType: c.innerCodec.Name(),
+ Args: reqsBytes,
+ ArgTypes: reqsTypes,
+ }
+
+ return proto.Marshal(wrapperReq)
+}
+
+func (c *protoWrapperCodec) Unmarshal(binary []byte, message interface{}) error {
+ params, ok := message.([]interface{})
+ if !ok {
+ return c.innerCodec.Unmarshal(binary, message)
+ }
+
+ var wrapperReq proto_wrapper_api.TripleRequestWrapper
+ if err := proto.Unmarshal(binary, &wrapperReq); err != nil {
+ return err
+ }
+ if len(wrapperReq.Args) != len(params) {
+ return fmt.Errorf("error, request params len is %d, but has %d actually", len(wrapperReq.Args), len(params))
+ }
+
+ for i, arg := range wrapperReq.Args {
+ if err := c.innerCodec.Unmarshal(arg, params[i]); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func newProtoWrapperCodec(innerCodec Codec) *protoWrapperCodec {
+ return &protoWrapperCodec{innerCodec: innerCodec}
+}
+
+// todo(DMwangnima): add unit tests
+type hessian2Codec struct{}
+
+func (h *hessian2Codec) Name() string {
+ return codecNameHessian2
+}
+
+func (c *hessian2Codec) Marshal(message interface{}) ([]byte, error) {
+ encoder := hessian.NewEncoder()
+ if err := encoder.Encode(message); err != nil {
+ return nil, err
+ }
+
+ return encoder.Buffer(), nil
+}
+
+func (c *hessian2Codec) Unmarshal(binary []byte, message interface{}) error {
+ decoder := hessian.NewDecoder(binary)
+ val, err := decoder.Decode()
+ if err != nil {
+ return err
+ }
+ return tools.ReflectResponse(val, message)
+}
+
+// todo(DMwangnima): add unit tests
+type msgpackCodec struct{}
+
+func (c *msgpackCodec) Name() string {
+ return codecNameMsgPack
+}
+
+func (c *msgpackCodec) Marshal(message interface{}) ([]byte, error) {
+ var out []byte
+ encoder := msgpack.NewEncoderBytes(&out, new(msgpack.MsgpackHandle))
+ return out, encoder.Encode(message)
+}
+
+func (c *msgpackCodec) Unmarshal(binary []byte, message interface{}) error {
+ decoder := msgpack.NewDecoderBytes(binary, new(msgpack.MsgpackHandle))
+ return decoder.Decode(message)
+}
+
// readOnlyCodecs is a read-only interface to a map of named codecs.
type readOnlyCodecs interface {
// Get gets the Codec with the given name.
diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go
index 3c02ee6..56f83b3 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -340,11 +340,12 @@
_ = connCloser.Close(timeoutErr)
return
}
+
// invoke implementation
svcGroup := request.Header.Get(tripleServiceGroup)
svcVersion := request.Header.Get(tripleServiceVersion)
- implementation := h.implementations[getIdentifier(svcGroup, svcVersion)]
// todo(DMwangnima): inspect ok
+ implementation := h.implementations[getIdentifier(svcGroup, svcVersion)]
_ = connCloser.Close(implementation(ctx, connCloser))
}
@@ -376,6 +377,8 @@
}
withProtoBinaryCodec().applyToHandler(&config)
withProtoJSONCodecs().applyToHandler(&config)
+ withHessian2Codec().applyToHandler(&config)
+ withMsgPackCodec().applyToHandler(&config)
withGzip().applyToHandler(&config)
for _, opt := range options {
opt.applyToHandler(&config)
diff --git a/protocol/triple/triple_protocol/handler_compat.go b/protocol/triple/triple_protocol/handler_compat.go
index 7c21993..89d14de 100644
--- a/protocol/triple/triple_protocol/handler_compat.go
+++ b/protocol/triple/triple_protocol/handler_compat.go
@@ -19,12 +19,16 @@
import (
"context"
+ "errors"
"fmt"
+ "github.com/dubbogo/grpc-go/metadata"
+ "github.com/golang/protobuf/proto"
"net/http"
)
import (
"github.com/dubbogo/grpc-go"
+ "github.com/dubbogo/grpc-go/status"
)
import (
@@ -56,18 +60,35 @@
if !ok {
return nil, errorf(CodeInternal, "unexpected handler request type %T", request)
}
- respRaw, err := handler(ctx, typed.Any())
- if respRaw == nil && err == nil {
+ dubbo3RespRaw, err := handler(ctx, typed.Any())
+ if dubbo3RespRaw == nil && err == nil {
// This is going to panic during serialization. Debugging is much easier
// if we panic here instead, so we can include the procedure name.
panic(fmt.Sprintf("%s returned nil resp and nil error", t.procedure)) //nolint: forbidigo
}
- resp, ok := respRaw.(*dubbo_protocol.RPCResult)
+ dubbo3Resp, ok := dubbo3RespRaw.(*dubbo_protocol.RPCResult)
if !ok {
- panic(fmt.Sprintf("%+v is not of type *RPCResult", respRaw))
+ panic(fmt.Sprintf("%+v is not of type *RPCResult", dubbo3RespRaw))
+ }
+ dubbo3Err, ok := compatError(err)
+ if ok {
+ err = dubbo3Err
}
// todo(DMwangnima): expose API for users to write response headers and trailers
- return NewResponse(resp.Rest), err
+ resp := NewResponse(dubbo3Resp.Rest)
+ trailer := make(http.Header)
+ for key, valRaw := range dubbo3Resp.Attachments() {
+ switch valRaw.(type) {
+ case string:
+ trailer[key] = []string{valRaw.(string)}
+ case []string:
+ trailer[key] = valRaw.([]string)
+ default:
+ panic(fmt.Sprintf("unsupported attachment value type %T", valRaw))
+ }
+ }
+ resp.trailer = trailer
+ return resp, err
}
if t.interceptor != nil {
@@ -79,12 +100,13 @@
func NewCompatUnaryHandler(
procedure string,
+ method string,
srv interface{},
unary MethodHandler,
options ...HandlerOption,
) *Handler {
config := newHandlerConfig(procedure, options)
- implementation := generateCompatUnaryHandlerFunc(procedure, srv, unary, config.Interceptor)
+ implementation := generateCompatUnaryHandlerFunc(procedure, method, srv, unary, config.Interceptor)
protocolHandlers := config.newProtocolHandlers(StreamTypeUnary)
hdl := &Handler{
@@ -101,6 +123,7 @@
func generateCompatUnaryHandlerFunc(
procedure string,
+ method string,
srv interface{},
unary MethodHandler,
interceptor Interceptor,
@@ -119,9 +142,14 @@
}
return nil
}
+ ctx = metadata.NewIncomingContext(ctx, metadata.MD(conn.ExportableHeader()))
+ // staticcheck error: SA1029. dubbo3 code needs to make use of "XXX_TRIPLE_GO_METHOD_NAME"
+ //nolint:staticcheck
+ ctx = context.WithValue(ctx, "XXX_TRIPLE_GO_METHOD_NAME", method)
// staticcheck error: SA1029. Stub code generated by protoc-gen-go-triple makes use of "XXX_TRIPLE_GO_INTERFACE_NAME" directly
//nolint:staticcheck
ctx = context.WithValue(ctx, "XXX_TRIPLE_GO_INTERFACE_NAME", procedure)
+ // todo(DMwangnima): deal with XXX_TRIPLE_GO_GENERIC_PAYLOAD
respRaw, err := unary(srv, ctx, decodeFunc, compatInterceptor.compatUnaryServerInterceptor)
if err != nil {
return err
@@ -133,3 +161,25 @@
return conn.Send(resp.Any())
}
}
+
+func compatError(err error) (*Error, bool) {
+ if err == nil {
+ return nil, false
+ }
+ s, ok := status.FromError(err)
+ if !ok {
+ return nil, false
+ }
+
+ triErr := NewError(Code(s.Code()), errors.New(s.Message()))
+ for _, detail := range s.Details() {
+ // dubbo3 detail use MessageV1, we need to convert it to MessageV2
+ errDetail, e := NewErrorDetail(proto.MessageV2(detail.(proto.Message)))
+ if e != nil {
+ return nil, false
+ }
+ triErr.AddDetail(errDetail)
+ }
+
+ return triErr, ok
+}
diff --git a/protocol/triple/triple_protocol/handler_compat_test.go b/protocol/triple/triple_protocol/handler_compat_test.go
new file mode 100644
index 0000000..b6db682
--- /dev/null
+++ b/protocol/triple/triple_protocol/handler_compat_test.go
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+ "github.com/dubbogo/grpc-go/codes"
+ "github.com/dubbogo/grpc-go/status"
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+func TestCompatError(t *testing.T) {
+ err := status.Error(codes.Code(1234), "user defined")
+ triErr, ok := compatError(err)
+ assert.True(t, ok)
+ assert.Equal(t, Code(1234), triErr.Code())
+ assert.Equal(t, "user defined", triErr.Message())
+ assert.Equal(t, 1, len(triErr.Details()))
+}
diff --git a/protocol/triple/triple_protocol/handler_ext_test.go b/protocol/triple/triple_protocol/handler_ext_test.go
index 80f0a6e..90e6ef4 100644
--- a/protocol/triple/triple_protocol/handler_ext_test.go
+++ b/protocol/triple/triple_protocol/handler_ext_test.go
@@ -122,11 +122,15 @@
assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType)
assert.Equal(t, resp.Header.Get("Accept-Post"), strings.Join([]string{
"application/grpc",
+ "application/grpc+hessian2",
"application/grpc+json",
"application/grpc+json; charset=utf-8",
+ "application/grpc+msgpack",
"application/grpc+proto",
+ "application/hessian2",
"application/json",
"application/json; charset=utf-8",
+ "application/msgpack",
"application/proto",
}, ", "))
})
diff --git a/protocol/triple/triple_protocol/handler_stream.go b/protocol/triple/triple_protocol/handler_stream.go
index c537089..a92f4ef 100644
--- a/protocol/triple/triple_protocol/handler_stream.go
+++ b/protocol/triple/triple_protocol/handler_stream.go
@@ -146,6 +146,11 @@
return b.conn.RequestHeader()
}
+// ExportableHeader returns the headers could be exported to users.
+func (b *BidiStream) ExportableHeader() http.Header {
+ return b.conn.ExportableHeader()
+}
+
// Receive a message. When the client is done sending messages, Receive will
// return an error that wraps [io.EOF].
func (b *BidiStream) Receive(msg interface{}) error {
diff --git a/protocol/triple/triple_protocol/handler_stream_compat.go b/protocol/triple/triple_protocol/handler_stream_compat.go
index bf96992..6a0fbb3 100644
--- a/protocol/triple/triple_protocol/handler_stream_compat.go
+++ b/protocol/triple/triple_protocol/handler_stream_compat.go
@@ -50,6 +50,10 @@
return c.ctx
}
+func (c *compatHandlerStream) SetContext(ctx context.Context) {
+ c.ctx = ctx
+}
+
func (c *compatHandlerStream) SendMsg(m interface{}) error {
return c.conn.Send(m)
}
diff --git a/protocol/triple/triple_protocol/header_compat.go b/protocol/triple/triple_protocol/header_compat.go
new file mode 100644
index 0000000..3782c09
--- /dev/null
+++ b/protocol/triple/triple_protocol/header_compat.go
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+// These keys are for compatible usage
+const (
+ TripleContentType = "application/grpc+proto"
+ TripleUserAgent = "grpc-go/1.35.0-dev"
+ TripleServiceVersion = "tri-service-version"
+ TripleAttachement = "tri-attachment"
+ TripleServiceGroup = "tri-service-group"
+ TripleRequestID = "tri-req-id"
+ TripleTraceID = "tri-trace-traceid"
+ TripleTraceRPCID = "tri-trace-rpcid"
+ TripleTraceProtoBin = "tri-trace-proto-bin"
+ TripleUnitInfo = "tri-unit-info"
+)
diff --git a/protocol/triple/triple_protocol/option.go b/protocol/triple/triple_protocol/option.go
index 5cef7d1..8383bcd 100644
--- a/protocol/triple/triple_protocol/option.go
+++ b/protocol/triple/triple_protocol/option.go
@@ -76,6 +76,15 @@
return WithCodec(&protoJSONCodec{codecNameJSON})
}
+// todo(DMwangnima): add comment
+func WithHessian2() ClientOption {
+ return WithCodec(newProtoWrapperCodec(&hessian2Codec{}))
+}
+
+func WithMsgPack() ClientOption {
+ return WithCodec(newProtoWrapperCodec(&msgpackCodec{}))
+}
+
// WithSendCompression configures the client to use the specified algorithm to
// compress request messages. If the algorithm has not been registered using
// [WithAcceptCompression], the client will return errors at runtime.
@@ -545,3 +554,11 @@
WithCodec(&protoJSONCodec{codecNameJSONCharsetUTF8}),
)
}
+
+func withHessian2Codec() Option {
+ return WithCodec(newProtoWrapperCodec(&hessian2Codec{}))
+}
+
+func withMsgPackCodec() Option {
+ return WithCodec(newProtoWrapperCodec(&msgpackCodec{}))
+}
diff --git a/protocol/triple/triple_protocol/protocol_grpc.go b/protocol/triple/triple_protocol/protocol_grpc.go
index 80dff38..f52cf0e 100644
--- a/protocol/triple/triple_protocol/protocol_grpc.go
+++ b/protocol/triple/triple_protocol/protocol_grpc.go
@@ -461,6 +461,62 @@
return hc.request.Header
}
+func (hc *grpcHandlerConn) ExportableHeader() http.Header {
+ // todo(DMwangnima): check out whether res should be cached
+ res := make(http.Header)
+ hdr := hc.request.Header
+ for key, vals := range hdr {
+ key = strings.ToLower(key)
+ if isReservedHeader(key) && !isWhitelistedHeader(key) {
+ continue
+ }
+ cloneVals := make([]string, len(vals))
+ for i, val := range vals {
+ cloneVals[i] = val
+ }
+ res[key] = cloneVals
+ }
+
+ return res
+}
+
+// isReservedHeader checks whether hdr belongs to HTTP2 headers
+// reserved by gRPC protocol. Any other headers are classified as the
+// user-specified metadata.
+func isReservedHeader(hdr string) bool {
+ if hdr != "" && hdr[0] == ':' {
+ return true
+ }
+ switch hdr {
+ case "content-type",
+ "user-agent",
+ "grpc-message-type",
+ "grpc-encoding",
+ "grpc-message",
+ "grpc-status",
+ "grpc-timeout",
+ "grpc-status-details-bin",
+ // Intentionally exclude grpc-previous-rpc-attempts and
+ // grpc-retry-pushback-ms, which are "reserved", but their API
+ // intentionally works via metadata.
+ "te":
+ return true
+ default:
+ return false
+ }
+}
+
+// isWhitelistedHeader checks whether hdr should be propagated into metadata
+// visible to users, even though it is classified as "reserved", above.
+func isWhitelistedHeader(hdr string) bool {
+ switch hdr {
+ case ":authority", "user-agent":
+ return true
+ default:
+ return false
+ }
+}
+
func (hc *grpcHandlerConn) Send(msg interface{}) error {
defer flushResponseWriter(hc.responseWriter)
if !hc.wroteToBody {
diff --git a/protocol/triple/triple_protocol/protocol_triple.go b/protocol/triple/triple_protocol/protocol_triple.go
index 84a8b66..5a1207b 100644
--- a/protocol/triple/triple_protocol/protocol_triple.go
+++ b/protocol/triple/triple_protocol/protocol_triple.go
@@ -417,6 +417,11 @@
return hc.request.Header
}
+func (hc *tripleUnaryHandlerConn) ExportableHeader() http.Header {
+ // by now, there are no reserved headers
+ return hc.request.Header
+}
+
func (hc *tripleUnaryHandlerConn) Send(msg interface{}) error {
hc.wroteBody = true
hc.writeResponseHeader(nil /* error */)
diff --git a/protocol/triple/triple_protocol/server.go b/protocol/triple/triple_protocol/server.go
index 0c09c7b..26055be 100644
--- a/protocol/triple/triple_protocol/server.go
+++ b/protocol/triple/triple_protocol/server.go
@@ -32,6 +32,7 @@
type Server struct {
mu sync.Mutex
+ mux *http.ServeMux
handlers map[string]*Handler
httpSrv *http.Server
}
@@ -46,6 +47,7 @@
if !ok {
hdl = NewUnaryHandler(procedure, reqInitFunc, unary, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateUnaryHandlerFunc(procedure, reqInitFunc, unary, config.Interceptor)
@@ -64,6 +66,7 @@
if !ok {
hdl = NewClientStreamHandler(procedure, stream, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateClientStreamHandlerFunc(procedure, stream, config.Interceptor)
@@ -83,6 +86,7 @@
if !ok {
hdl = NewServerStreamHandler(procedure, reqInitFunc, stream, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateServerStreamHandlerFunc(procedure, reqInitFunc, stream, config.Interceptor)
@@ -101,6 +105,7 @@
if !ok {
hdl = NewBidiStreamHandler(procedure, stream, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateBidiStreamHandlerFunc(procedure, stream, config.Interceptor)
@@ -112,17 +117,19 @@
func (s *Server) RegisterCompatUnaryHandler(
procedure string,
+ method string,
srv interface{},
unary MethodHandler,
options ...HandlerOption,
) error {
hdl, ok := s.handlers[procedure]
if !ok {
- hdl = NewCompatUnaryHandler(procedure, srv, unary, options...)
+ hdl = NewCompatUnaryHandler(procedure, method, srv, unary, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
- implementation := generateCompatUnaryHandlerFunc(procedure, srv, unary, config.Interceptor)
+ implementation := generateCompatUnaryHandlerFunc(procedure, method, srv, unary, config.Interceptor)
hdl.processImplementation(getIdentifier(config.Group, config.Version), implementation)
}
@@ -140,6 +147,7 @@
if !ok {
hdl = NewCompatStreamHandler(procedure, srv, typ, streamFunc, options...)
s.handlers[procedure] = hdl
+ s.mux.Handle(procedure, hdl)
} else {
config := newHandlerConfig(procedure, options)
implementation := generateCompatStreamHandlerFunc(procedure, srv, streamFunc, config.Interceptor)
@@ -150,12 +158,8 @@
}
func (s *Server) Run() error {
- mux := http.NewServeMux()
- for procedure, hdl := range s.handlers {
- mux.Handle(procedure, hdl)
- }
// todo(DMwangnima): deal with TLS
- s.httpSrv.Handler = h2c.NewHandler(mux, &http2.Server{})
+ s.httpSrv.Handler = h2c.NewHandler(s.mux, &http2.Server{})
if err := s.httpSrv.ListenAndServe(); err != nil {
return err
@@ -173,6 +177,7 @@
func NewServer(addr string) *Server {
return &Server{
+ mux: http.NewServeMux(),
handlers: make(map[string]*Handler),
httpSrv: &http.Server{Addr: addr},
}
diff --git a/protocol/triple/triple_protocol/server_test.go b/protocol/triple/triple_protocol/server_test.go
new file mode 100644
index 0000000..ec962fe
--- /dev/null
+++ b/protocol/triple/triple_protocol/server_test.go
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple_protocol
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net/http"
+ "net/url"
+ "testing"
+)
+
+func TestServer_RegisterMuxHandle(t *testing.T) {
+ tests := []struct {
+ desc string
+ path string
+ registerFunc func(srv *Server, path string) error
+ }{
+ {
+ desc: "RegisterUnaryHandler_MuxHandle",
+ path: "/Unary",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterUnaryHandler(path, nil, nil)
+ },
+ },
+ {
+ desc: "RegisterClientStreamHandler_MuxHandle",
+ path: "/ClientStream",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterClientStreamHandler(path, nil)
+ },
+ },
+ {
+ desc: "RegisterServerStreamHandler_MuxHandle",
+ path: "/ServerStream",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterServerStreamHandler(path, nil, nil)
+ },
+ },
+ {
+ desc: "RegisterBidiStreamHandler_MuxHandle",
+ path: "/BidiStream",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterBidiStreamHandler(path, nil)
+ },
+ },
+ {
+ desc: "RegisterCompatUnaryHandler_MuxHandle",
+ path: "/CompatUnary",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterCompatUnaryHandler(path, "", nil, nil)
+ },
+ },
+ {
+ desc: "RegisterCompatStreamHandler_MuxHandle",
+ path: "/CompatStream",
+ registerFunc: func(srv *Server, path string) error {
+ return srv.RegisterCompatStreamHandler(path, nil, StreamTypeBidi, nil)
+ },
+ },
+ }
+
+ srv := NewServer("127.0.0.1:20000")
+ for _, test := range tests {
+ err := srv.RegisterUnaryHandler(test.path, nil, nil)
+ assert.Nil(t, err)
+ _, pattern := srv.mux.Handler(&http.Request{
+ URL: &url.URL{
+ Path: test.path,
+ },
+ })
+ assert.Equal(t, test.path, pattern)
+ }
+}
diff --git a/protocol/triple/triple_protocol/triple.go b/protocol/triple/triple_protocol/triple.go
index e56dd4b..68f711a 100644
--- a/protocol/triple/triple_protocol/triple.go
+++ b/protocol/triple/triple_protocol/triple.go
@@ -84,6 +84,7 @@
Receive(interface{}) error
RequestHeader() http.Header
+ ExportableHeader() http.Header
Send(interface{}) error
ResponseHeader() http.Header
diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go
index acaf10c..dd72625 100644
--- a/protocol/triple/triple_test.go
+++ b/protocol/triple/triple_test.go
@@ -54,13 +54,15 @@
)
const (
- triplePort = "21000"
- dubbo3Port = "21001"
- listenAddr = "0.0.0.0"
- localAddr = "127.0.0.1"
- name = "triple"
- group = "g1"
- version = "v1"
+ triplePort = "21000"
+ dubbo3Port = "21001"
+ listenAddr = "0.0.0.0"
+ localAddr = "127.0.0.1"
+ name = "triple"
+ group = "g1"
+ version = "v1"
+ customTripleInterfaceName = "apache.dubbo.org.triple"
+ customDubbo3InterfaceName = "apache.dubbo.org.dubbo3"
)
type tripleInvoker struct {
@@ -85,7 +87,7 @@
func (t *tripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
name := invocation.MethodName()
args := invocation.Arguments()
- // todo(DMwangnima): user map to represent Methods
+ // todo(DMwangnima): use map to represent Methods
for _, method := range t.info.Methods {
if method.Name == name {
res, err := method.MethodFunc(ctx, args, t.handler)
@@ -103,6 +105,8 @@
common.WithPath(interfaceName),
common.WithLocation(addr),
common.WithPort(triplePort),
+ common.WithProtocol(TRIPLE),
+ common.WithInterface(interfaceName),
)
url.SetParam(constant.GroupKey, group)
url.SetParam(constant.VersionKey, version)
@@ -121,7 +125,7 @@
func runOldTripleServer(interfaceName string, group string, version string, addr string, desc *grpc_go.ServiceDesc, svc common.RPCService) {
url := common.NewURLWithOptions(
// todo(DMwangnima): figure this out
- common.WithPath(desc.ServiceName),
+ common.WithPath(interfaceName),
common.WithLocation(addr),
common.WithPort(dubbo3Port),
common.WithProtocol(TRIPLE),
@@ -143,14 +147,14 @@
Build()).
Build())
config.SetProviderService(svc)
- common.ServiceMap.Register(desc.ServiceName, TRIPLE, group, version, svc)
+ common.ServiceMap.Register(interfaceName, TRIPLE, group, version, svc)
invoker := extension.GetProxyFactory("default").GetInvoker(url)
GetProtocol().(*TripleProtocol).exportForTest(invoker, nil)
}
func TestMain(m *testing.M) {
runTripleServer(
- greettriple.GreetServiceName,
+ customTripleInterfaceName,
"",
"",
listenAddr,
@@ -158,7 +162,7 @@
new(api.GreetTripleServer),
)
runTripleServer(
- greettriple.GreetServiceName,
+ customTripleInterfaceName,
group,
version,
listenAddr,
@@ -166,7 +170,7 @@
new(api.GreetTripleServerGroup1Version1),
)
runOldTripleServer(
- dubbo3_greet.GreetService_ServiceDesc.ServiceName,
+ customDubbo3InterfaceName,
"",
"",
listenAddr,
@@ -174,7 +178,7 @@
new(dubbo3_api.GreetDubbo3Server),
)
runOldTripleServer(
- dubbo3_greet.GreetService_ServiceDesc.ServiceName,
+ customDubbo3InterfaceName,
group,
version,
listenAddr,
@@ -431,46 +435,46 @@
}
t.Run("triple2triple", func(t *testing.T) {
- invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
+ invoker, err := tripleInvokerInit(localAddr, triplePort, customTripleInterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, "")
})
t.Run("triple2triple_Group1Version1", func(t *testing.T) {
- invoker, err := tripleInvokerInit(localAddr, triplePort, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
+ invoker, err := tripleInvokerInit(localAddr, triplePort, customTripleInterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, api.GroupVersionIdentifier)
})
t.Run("triple2dubbo3", func(t *testing.T) {
- invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
+ invoker, err := tripleInvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, "", "", greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, "")
})
t.Run("triple2dubbo3_Group1Version1", func(t *testing.T) {
- invoker, err := tripleInvokerInit(localAddr, dubbo3Port, greettriple.GreetService_ClientInfo.InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
+ invoker, err := tripleInvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, group, version, greettriple.GreetService_ClientInfo.MethodNames, &greettriple.GreetService_ClientInfo)
assert.Nil(t, err)
invokeTripleCodeFunc(t, invoker, dubbo3_api.GroupVersionIdentifier)
})
t.Run("dubbo32triple", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
- invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc)
+ invoker, err := dubbo3InvokerInit(localAddr, triplePort, customTripleInterfaceName, "", "", svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, "")
})
t.Run("dubbo32triple_Group1Version1", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
- invoker, err := dubbo3InvokerInit(localAddr, triplePort, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc)
+ invoker, err := dubbo3InvokerInit(localAddr, triplePort, customTripleInterfaceName, group, version, svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, api.GroupVersionIdentifier)
})
t.Run("dubbo32dubbo3", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
- invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, "", "", svc)
+ invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, "", "", svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, "")
})
t.Run("dubbo32dubbo3_Group1Version1", func(t *testing.T) {
svc := new(dubbo3_greet.GreetServiceClientImpl)
- invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, dubbo3_greet.GreetService_ServiceDesc.ServiceName, group, version, svc)
+ invoker, err := dubbo3InvokerInit(localAddr, dubbo3Port, customDubbo3InterfaceName, group, version, svc)
assert.Nil(t, err)
invokeDubbo3CodeFunc(t, invoker, svc, dubbo3_api.GroupVersionIdentifier)
})
diff --git a/server/action.go b/server/action.go
index cf62658..98a40ec 100644
--- a/server/action.go
+++ b/server/action.go
@@ -139,7 +139,7 @@
if svc.Interface == "" {
svc.Interface = info.InterfaceName
}
- svcOpts.Id = info.InterfaceName
+ svcOpts.Id = common.GetReference(svcOpts.rpcService)
svcOpts.info = info
}
// TODO: delay needExport