feat:sort triple logic (#2483)
diff --git a/protocol/triple/server.go b/protocol/triple/server.go
index 79465df..cc0aba2 100644
--- a/protocol/triple/server.go
+++ b/protocol/triple/server.go
@@ -49,11 +49,14 @@
// Server is TRIPLE server
type Server struct {
httpServer *http.Server
+ handler *http.ServeMux
}
// NewServer creates a new TRIPLE server
func NewServer() *Server {
- return &Server{}
+ return &Server{
+ handler: http.NewServeMux(),
+ }
}
// Start TRIPLE server
@@ -69,19 +72,6 @@
srv := &http.Server{
Addr: addr,
}
-
- maxServerRecvMsgSize := constant.DefaultMaxServerRecvMsgSize
- if recvMsgSize, convertErr := humanize.ParseBytes(URL.GetParam(constant.MaxServerRecvMsgSize, "")); convertErr == nil && recvMsgSize != 0 {
- maxServerRecvMsgSize = int(recvMsgSize)
- }
- hanOpts = append(hanOpts, tri.WithReadMaxBytes(maxServerRecvMsgSize))
-
- maxServerSendMsgSize := constant.DefaultMaxServerSendMsgSize
- if sendMsgSize, convertErr := humanize.ParseBytes(URL.GetParam(constant.MaxServerSendMsgSize, "")); err == convertErr && sendMsgSize != 0 {
- maxServerSendMsgSize = int(sendMsgSize)
- }
- hanOpts = append(hanOpts, tri.WithSendMaxBytes(maxServerSendMsgSize))
-
serialization := URL.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
switch serialization {
case constant.ProtobufSerialization:
@@ -89,7 +79,6 @@
default:
panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
}
-
// todo: implement interceptor
// If global trace instance was set, then server tracer instance
// can be get. If not, will return NoopTracer.
@@ -116,14 +105,13 @@
// logger.Infof("Triple Server initialized the TLSConfig configuration")
//}
//srv.TLSConfig = cfg
-
- // todo:// open tracing
- hanOpts = append(hanOpts, tri.WithInterceptors())
// todo:// move tls config to handleService
+
+ hanOpts = getHanOpts(URL)
s.httpServer = srv
go func() {
- mux := http.NewServeMux()
+ mux := s.handler
if info != nil {
handleServiceWithInfo(invoker, info, mux, hanOpts...)
} else {
@@ -144,6 +132,48 @@
}()
}
+// RefreshService refreshes Triple Service
+func (s *Server) RefreshService(invoker protocol.Invoker, info *server.ServiceInfo) {
+ var (
+ URL *common.URL
+ hanOpts []tri.HandlerOption
+ )
+ URL = invoker.GetURL()
+ serialization := URL.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
+ switch serialization {
+ case constant.ProtobufSerialization:
+ case constant.JSONSerialization:
+ default:
+ panic(fmt.Sprintf("Unsupported serialization: %s", serialization))
+ }
+ hanOpts = getHanOpts(URL)
+ mux := s.handler
+ if info != nil {
+ handleServiceWithInfo(invoker, info, mux, hanOpts...)
+ } else {
+ compatHandleService(mux)
+ }
+}
+
+func getHanOpts(url *common.URL) (hanOpts []tri.HandlerOption) {
+ var err error
+ maxServerRecvMsgSize := constant.DefaultMaxServerRecvMsgSize
+ if recvMsgSize, convertErr := humanize.ParseBytes(url.GetParam(constant.MaxServerRecvMsgSize, "")); convertErr == nil && recvMsgSize != 0 {
+ maxServerRecvMsgSize = int(recvMsgSize)
+ }
+ hanOpts = append(hanOpts, tri.WithReadMaxBytes(maxServerRecvMsgSize))
+
+ maxServerSendMsgSize := constant.DefaultMaxServerSendMsgSize
+ if sendMsgSize, convertErr := humanize.ParseBytes(url.GetParam(constant.MaxServerSendMsgSize, "")); err == convertErr && sendMsgSize != 0 {
+ maxServerSendMsgSize = int(sendMsgSize)
+ }
+ hanOpts = append(hanOpts, tri.WithSendMaxBytes(maxServerSendMsgSize))
+
+ // todo:// open tracing
+ hanOpts = append(hanOpts, tri.WithInterceptors())
+ return hanOpts
+}
+
// getSyncMapLen gets sync map len
func getSyncMapLen(m *sync.Map) int {
length := 0
diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go
index cb64aa4..f5d523a 100644
--- a/protocol/triple/triple.go
+++ b/protocol/triple/triple.go
@@ -88,6 +88,7 @@
defer tp.serverLock.Unlock()
if _, ok := tp.serverMap[url.Location]; ok {
+ tp.serverMap[url.Location].RefreshService(invoker, info)
return
}
diff --git a/server/server.go b/server/server.go
index 2786cb6..e8add11 100644
--- a/server/server.go
+++ b/server/server.go
@@ -20,6 +20,7 @@
import (
"context"
"fmt"
+ "sync"
)
import (
@@ -38,6 +39,8 @@
info *ServiceInfo
cfg *ServerOptions
+
+ svcOptsMap sync.Map
}
// ServiceInfo is meta info of a service
@@ -145,14 +148,26 @@
return err
}
newSvcOpts.Implement(handler)
- if err := newSvcOpts.ExportWithInfo(info); err != nil {
- return err
- }
+ s.svcOptsMap.Store(newSvcOpts, info)
return nil
}
+func (s *Server) exportServices() (err error) {
+ s.svcOptsMap.Range(func(newSvcOpts, info interface{}) bool {
+ err = newSvcOpts.(*ServiceOptions).ExportWithInfo(info.(*ServiceInfo))
+ if err != nil {
+ return false
+ }
+ return true
+ })
+ return err
+}
+
func (s *Server) Serve() error {
+ if err := s.exportServices(); err != nil {
+ return err
+ }
metadata.ExportMetadataService()
registry_exposed.RegisterServiceInstance(s.cfg.Application.Name, s.cfg.Application.Tag, s.cfg.Application.MetadataType)
select {}