[ISSUE #314] support name server domain (#315)
* support name server domain
diff --git a/consumer/consumer.go b/consumer/consumer.go
index 6ebe9ff..45056b8 100644
--- a/consumer/consumer.go
+++ b/consumer/consumer.go
@@ -271,7 +271,9 @@
}
func (dc *defaultConsumer) start() error {
-
+ if len(dc.option.NameServerAddrs) == 0 {
+ dc.namesrv.UpdateNameServerAddress(dc.option.NameServerDomain, dc.option.InstanceName)
+ }
if dc.model == Clustering {
// set retry topic
retryTopic := internal.GetRetryTopic(dc.consumerGroup)
diff --git a/consumer/option.go b/consumer/option.go
index 323753e..6e7f4d2 100644
--- a/consumer/option.go
+++ b/consumer/option.go
@@ -186,6 +186,13 @@
}
}
+// WithNameServerDomain set NameServer domain
+func WithNameServerDomain(nameServerUrl string) Option {
+ return func(opts *consumerOptions) {
+ opts.NameServerDomain = nameServerUrl
+ }
+}
+
// WithNamespace set the namespace of consumer
func WithNamespace(namespace string) Option {
return func(opts *consumerOptions) {
diff --git a/internal/client.go b/internal/client.go
index c952385..17c089b 100644
--- a/internal/client.go
+++ b/internal/client.go
@@ -98,6 +98,7 @@
type ClientOptions struct {
GroupName string
NameServerAddrs primitive.NamesrvAddr
+ NameServerDomain string
Namesrv *namesrvs
ClientIP string
InstanceName string
@@ -259,8 +260,26 @@
if !c.option.Credentials.IsEmpty() {
c.remoteClient.RegisterInterceptor(remote.ACLInterceptor(c.option.Credentials))
}
- // TODO fetchNameServerAddr
- go func() {}()
+ // fetchNameServerAddr
+ if len(c.option.NameServerAddrs) == 0 {
+ go func() {
+ // delay
+ ticker := time.NewTicker(60 * 2 * time.Second)
+ defer ticker.Stop()
+ time.Sleep(50 * time.Millisecond)
+ for {
+ select {
+ case <-ticker.C:
+ c.namesrvs.UpdateNameServerAddress(c.option.NameServerDomain, c.option.InstanceName)
+ case <-c.done:
+ rlog.Info("The RMQClient stopping update name server domain info.", map[string]interface{}{
+ "clientID": c.ClientID(),
+ })
+ return
+ }
+ }
+ }()
+ }
// schedule update route info
go func() {
diff --git a/internal/mock_namesrv.go b/internal/mock_namesrv.go
index 19a2181..365e784 100644
--- a/internal/mock_namesrv.go
+++ b/internal/mock_namesrv.go
@@ -74,6 +74,18 @@
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "cleanOfflineBroker", reflect.TypeOf((*MockNamesrvs)(nil).cleanOfflineBroker))
}
+// UpdateNameServerAddress mocks base method
+func (m *MockNamesrvs) UpdateNameServerAddress(nameServer, instanceName string) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "UpdateNameServerAddress", nameServer, instanceName)
+}
+
+// UpdateNameServerAddress indicates an expected call of UpdateNameServerAddress
+func (mr *MockNamesrvsMockRecorder) UpdateNameServerAddress(nameServer, instanceName string) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNameServerAddress", reflect.TypeOf((*MockNamesrvs)(nil).UpdateNameServerAddress), nameServer, instanceName)
+}
+
// UpdateTopicRouteInfo mocks base method
func (m *MockNamesrvs) UpdateTopicRouteInfo(topic string) *TopicRouteData {
m.ctrl.T.Helper()
diff --git a/internal/namesrv.go b/internal/namesrv.go
index b9f1744..27d38e8 100644
--- a/internal/namesrv.go
+++ b/internal/namesrv.go
@@ -19,12 +19,23 @@
import (
"errors"
+ "fmt"
+ "github.com/apache/rocketmq-client-go/internal/remote"
+ "github.com/apache/rocketmq-client-go/primitive"
+ "github.com/apache/rocketmq-client-go/rlog"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "os/user"
+ "path"
"regexp"
"strings"
"sync"
+ "time"
+)
- "github.com/apache/rocketmq-client-go/internal/remote"
- "github.com/apache/rocketmq-client-go/primitive"
+const (
+ DEFAULT_NAMESRV_ADDR = "http://jmenv.tbsite.net:8080/rocketmq/nsaddr"
)
var (
@@ -37,6 +48,8 @@
//go:generate mockgen -source namesrv.go -destination mock_namesrv.go -self_package github.com/apache/rocketmq-client-go/internal --package internal Namesrvs
type Namesrvs interface {
+ UpdateNameServerAddress(nameServerDomain, instanceName string)
+
AddBroker(routeData *TopicRouteData)
cleanOfflineBroker()
@@ -125,3 +138,100 @@
func (s *namesrvs) AddrList() []string {
return s.srvs
}
+
+func getSnapshotFilePath(instanceName string) string {
+ homeDir := ""
+ if usr, err := user.Current(); err == nil {
+ homeDir = usr.HomeDir
+ } else {
+ rlog.Error("name server domain, can't get user home directory", map[string]interface{}{
+ "err": err,
+ })
+ }
+ storePath := path.Join(homeDir, "/logs/rocketmq-go/snapshot")
+ if _, err := os.Stat(storePath); os.IsNotExist(err) {
+ if err = os.MkdirAll(storePath, 0755); err != nil {
+ rlog.Fatal("can't create name server snapshot directory", map[string]interface{}{
+ "path": storePath,
+ "err": err,
+ })
+ }
+ }
+ filePath := path.Join(storePath, fmt.Sprintf("nameserver_addr-%s", instanceName))
+ return filePath
+}
+
+// UpdateNameServerAddress will update srvs.
+// docs: https://rocketmq.apache.org/docs/best-practice-namesvr/
+func (s *namesrvs) UpdateNameServerAddress(nameServerDomain, instanceName string) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if nameServerDomain == "" {
+ // try to get from environment variable
+ if v := os.Getenv("NAMESRV_ADDR"); v != "" {
+ s.srvs = strings.Split(v, ";")
+ return
+ }
+ // use default domain
+ nameServerDomain = DEFAULT_NAMESRV_ADDR
+ }
+
+ client := http.Client{Timeout: 10 * time.Second}
+ resp, err := client.Get(nameServerDomain)
+ if err == nil {
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ if err == nil {
+ oldBodyStr := strings.Join(s.srvs, ";")
+ bodyStr := string(body)
+ if bodyStr != "" && oldBodyStr != bodyStr {
+ s.srvs = strings.Split(string(body), ";")
+
+ rlog.Info("name server address changed", map[string]interface{}{
+ "old": oldBodyStr,
+ "new": bodyStr,
+ })
+ // save to local snapshot
+ filePath := getSnapshotFilePath(instanceName)
+ if err := ioutil.WriteFile(filePath, body, 0644); err == nil {
+ rlog.Info("name server snapshot save successfully", map[string]interface{}{
+ "filePath": filePath,
+ })
+ } else {
+ rlog.Error("name server snapshot save failed", map[string]interface{}{
+ "filePath": filePath,
+ "err": err,
+ })
+ }
+ }
+ rlog.Info("name server http fetch successfully", map[string]interface{}{
+ "addrs": bodyStr,
+ })
+ return
+ } else {
+ rlog.Error("name server http fetch failed", map[string]interface{}{
+ "NameServerDomain": nameServerDomain,
+ "err": err,
+ })
+ }
+ }
+
+ // load local snapshot if need when name server domain request failed
+ if len(s.srvs) == 0 {
+ filePath := getSnapshotFilePath(instanceName)
+ if _, err := os.Stat(filePath); !os.IsNotExist(err) {
+ if bs, err := ioutil.ReadFile(filePath); err == nil {
+ rlog.Info("load the name server snapshot local file", map[string]interface{}{
+ "filePath": filePath,
+ })
+ s.srvs = strings.Split(string(bs), ";")
+ return
+ }
+ } else {
+ rlog.Warning("name server snapshot local file not exists", map[string]interface{}{
+ "filePath": filePath,
+ })
+ }
+ }
+}
diff --git a/internal/namesrv_test.go b/internal/namesrv_test.go
index b6db2a0..ede14fc 100644
--- a/internal/namesrv_test.go
+++ b/internal/namesrv_test.go
@@ -18,6 +18,12 @@
package internal
import (
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "strings"
"sync"
"testing"
@@ -65,3 +71,177 @@
So(IP2, ShouldEqual, ns.srvs[index2])
})
}
+
+func TestUpdateNameServerAddress(t *testing.T) {
+ Convey("Test UpdateNameServerAddress method", t, func() {
+ srvs := []string{
+ "192.168.100.1",
+ "192.168.100.2",
+ "192.168.100.3",
+ "192.168.100.4",
+ "192.168.100.5",
+ }
+ http.HandleFunc("/nameserver/addrs", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, strings.Join(srvs, ";"))
+ })
+ server := &http.Server{Addr: ":0", Handler: nil}
+ listener, _ := net.Listen("tcp", ":0")
+ go server.Serve(listener)
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ nameServerDommain := fmt.Sprintf("http://127.0.0.1:%d/nameserver/addrs", port)
+ fmt.Println("temporary name server domain: ", nameServerDommain)
+
+ ns := &namesrvs{
+ srvs: []string{},
+ lock: new(sync.Mutex),
+ }
+ ns.UpdateNameServerAddress(nameServerDommain, "DEFAULT")
+
+ index1 := ns.index
+ IP1 := ns.getNameServerAddress()
+
+ index2 := ns.index
+ IP2 := ns.getNameServerAddress()
+
+ So(index1+1, ShouldEqual, index2)
+ So(IP1, ShouldEqual, srvs[index1])
+ So(IP2, ShouldEqual, srvs[index2])
+ })
+}
+
+func TestUpdateNameServerAddressSaveLocalSnapshot(t *testing.T) {
+ Convey("Test UpdateNameServerAddress Save Local Snapshot", t, func() {
+ srvs := []string{
+ "192.168.100.1",
+ "192.168.100.2",
+ "192.168.100.3",
+ "192.168.100.4",
+ "192.168.100.5",
+ }
+ http.HandleFunc("/nameserver/addrs2", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, strings.Join(srvs, ";"))
+ })
+ server := &http.Server{Addr: ":0", Handler: nil}
+ listener, _ := net.Listen("tcp", ":0")
+ go server.Serve(listener)
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ nameServerDommain := fmt.Sprintf("http://127.0.0.1:%d/nameserver/addrs2", port)
+ fmt.Println("temporary name server domain: ", nameServerDommain)
+
+ ns := &namesrvs{
+ srvs: []string{},
+ lock: new(sync.Mutex),
+ }
+ ns.UpdateNameServerAddress(nameServerDommain, "DEFAULT")
+ // check snapshot saved
+ filePath := getSnapshotFilePath("DEFAULT")
+ body := strings.Join(srvs, ";")
+ bs, _ := ioutil.ReadFile(filePath)
+ So(string(bs), ShouldEqual, body)
+ })
+}
+
+func TestUpdateNameServerAddressUseEnv(t *testing.T) {
+ Convey("Test UpdateNameServerAddress Use Env", t, func() {
+ srvs := []string{
+ "192.168.100.1",
+ "192.168.100.2",
+ "192.168.100.3",
+ "192.168.100.4",
+ "192.168.100.5",
+ }
+
+ ns := &namesrvs{
+ srvs: []string{},
+ lock: new(sync.Mutex),
+ }
+ os.Setenv("NAMESRV_ADDR", strings.Join(srvs, ";"))
+ ns.UpdateNameServerAddress("", "DEFAULT")
+
+ index1 := ns.index
+ IP1 := ns.getNameServerAddress()
+
+ index2 := ns.index
+ IP2 := ns.getNameServerAddress()
+
+ So(index1+1, ShouldEqual, index2)
+ So(IP1, ShouldEqual, srvs[index1])
+ So(IP2, ShouldEqual, srvs[index2])
+ })
+}
+
+func TestUpdateNameServerAddressUseSnapshotFile(t *testing.T) {
+ Convey("Test UpdateNameServerAddress Use Local Snapshot", t, func() {
+ srvs := []string{
+ "192.168.100.1",
+ "192.168.100.2",
+ "192.168.100.3",
+ "192.168.100.4",
+ "192.168.100.5",
+ }
+
+ ns := &namesrvs{
+ srvs: []string{},
+ lock: new(sync.Mutex),
+ }
+
+ os.Setenv("NAMESRV_ADDR", "") // clear env
+ // setup local snapshot file
+ filePath := getSnapshotFilePath("DEFAULT")
+ body := strings.Join(srvs, ";")
+ _ = ioutil.WriteFile(filePath, []byte(body), 0644)
+
+ ns.UpdateNameServerAddress("http://127.0.0.1:80/error/nsaddrs", "DEFAULT")
+
+ index1 := ns.index
+ IP1 := ns.getNameServerAddress()
+
+ index2 := ns.index
+ IP2 := ns.getNameServerAddress()
+
+ So(index1+1, ShouldEqual, index2)
+ So(IP1, ShouldEqual, srvs[index1])
+ So(IP2, ShouldEqual, srvs[index2])
+ })
+}
+
+func TestUpdateNameServerAddressLoadSnapshotFileOnce(t *testing.T) {
+ Convey("Test UpdateNameServerAddress Load Local Snapshot Once", t, func() {
+ srvs := []string{
+ "192.168.100.1",
+ "192.168.100.2",
+ "192.168.100.3",
+ "192.168.100.4",
+ "192.168.100.5",
+ }
+
+ ns := &namesrvs{
+ srvs: []string{},
+ lock: new(sync.Mutex),
+ }
+
+ os.Setenv("NAMESRV_ADDR", "") // clear env
+ // setup local snapshot file
+ filePath := getSnapshotFilePath("DEFAULT")
+ body := strings.Join(srvs, ";")
+ _ = ioutil.WriteFile(filePath, []byte(body), 0644)
+ // load local snapshot file first time
+ ns.UpdateNameServerAddress("http://127.0.0.1:80/error/nsaddrs", "DEFAULT")
+
+ // change the local snapshot file to check load once
+ _ = ioutil.WriteFile(filePath, []byte("127.0.0.1;127.0.0.2"), 0644)
+ ns.UpdateNameServerAddress("http://127.0.0.1:80/error/nsaddrs", "DEFAULT")
+
+ index1 := ns.index
+ IP1 := ns.getNameServerAddress()
+
+ index2 := ns.index
+ IP2 := ns.getNameServerAddress()
+
+ So(index1+1, ShouldEqual, index2)
+ So(IP1, ShouldEqual, srvs[index1])
+ So(IP2, ShouldEqual, srvs[index2])
+ })
+}