Ftr: enable filter and cluster when client consumer provider directly (#1181)
* URL directly call add filter and cluster
* update
* update
* add mockFilter
Co-authored-by: kezhan <kezhan@shizhuang-inc.com>
Co-authored-by: Xin.Zh <dragoncharlie@foxmail.com>
diff --git a/config/reference_config.go b/config/reference_config.go
index b04bdfa..fd7ef49 100644
--- a/config/reference_config.go
+++ b/config/reference_config.go
@@ -37,6 +37,7 @@
"dubbo.apache.org/dubbo-go/v3/common/extension"
"dubbo.apache.org/dubbo-go/v3/common/proxy"
"dubbo.apache.org/dubbo-go/v3/protocol"
+ "dubbo.apache.org/dubbo-go/v3/protocol/protocolwrapper"
)
// ReferenceConfig is the configuration of service consumer
@@ -134,11 +135,42 @@
if len(c.urls) == 1 {
c.invoker = extension.GetProtocol(c.urls[0].Protocol).Refer(c.urls[0])
+ // c.URL != "" is direct call
+ if c.URL != "" {
+ //filter
+ c.invoker = protocolwrapper.BuildInvokerChain(c.invoker, constant.REFERENCE_FILTER_KEY)
+
+ // cluster
+ invokers := make([]protocol.Invoker, 0, len(c.urls))
+ invokers = append(invokers, c.invoker)
+ // TODO(decouple from directory, config should not depend on directory module)
+ var hitClu string
+ // not a registry url, must be direct invoke.
+ hitClu = constant.FAILOVER_CLUSTER_NAME
+ if len(invokers) > 0 {
+ u := invokers[0].GetURL()
+ if nil != &u {
+ hitClu = u.GetParam(constant.CLUSTER_KEY, constant.ZONEAWARE_CLUSTER_NAME)
+ }
+ }
+
+ cluster := extension.GetCluster(hitClu)
+ // If 'zone-aware' policy select, the invoker wrap sequence would be:
+ // ZoneAwareClusterInvoker(StaticDirectory) ->
+ // FailoverClusterInvoker(RegistryDirectory, routing happens here) -> Invoker
+ c.invoker = cluster.Join(directory.NewStaticDirectory(invokers))
+ }
} else {
invokers := make([]protocol.Invoker, 0, len(c.urls))
var regURL *common.URL
for _, u := range c.urls {
- invokers = append(invokers, extension.GetProtocol(u.Protocol).Refer(u))
+ invoker := extension.GetProtocol(u.Protocol).Refer(u)
+ // c.URL != "" is direct call
+ if c.URL != "" {
+ //filter
+ invoker = protocolwrapper.BuildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
+ }
+ invokers = append(invokers, invoker)
if u.Protocol == constant.REGISTRY_PROTOCOL {
regURL = u
}
diff --git a/config/reference_config_test.go b/config/reference_config_test.go
index aaf9c46..9b5335a 100644
--- a/config/reference_config_test.go
+++ b/config/reference_config_test.go
@@ -18,6 +18,7 @@
package config
import (
+ "context"
"sync"
"testing"
)
@@ -31,6 +32,7 @@
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/common/extension"
+ "dubbo.apache.org/dubbo-go/v3/filter"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/registry"
)
@@ -193,7 +195,6 @@
doInitConsumer()
extension.SetProtocol("registry", GetProtocol)
extension.SetCluster(constant.ZONEAWARE_CLUSTER_NAME, cluster_impl.NewZoneAwareCluster)
-
for _, reference := range consumerConfig.References {
reference.Refer(nil)
assert.NotNil(t, reference.invoker)
@@ -234,6 +235,7 @@
func TestReferP2P(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
+ mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000"
@@ -248,6 +250,7 @@
func TestReferMultiP2P(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
+ mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;dubbo://127.0.0.2:20000"
@@ -263,6 +266,7 @@
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
+ mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"
@@ -291,6 +295,7 @@
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
+ mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"
@@ -308,6 +313,7 @@
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
+ mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"
@@ -333,7 +339,8 @@
return &mockRegistryProtocol{}
}
-type mockRegistryProtocol struct{}
+type mockRegistryProtocol struct {
+}
func (*mockRegistryProtocol) Refer(url *common.URL) protocol.Invoker {
return protocol.NewBaseInvoker(url)
@@ -375,3 +382,23 @@
func (p *mockRegistryProtocol) GetRegistries() []registry.Registry {
return []registry.Registry{&mockServiceDiscoveryRegistry{}}
}
+
+func mockFilter() {
+ consumerFiler := &mockShutdownFilter{}
+ extension.SetFilter(constant.CONSUMER_SHUTDOWN_FILTER, func() filter.Filter {
+ return consumerFiler
+ })
+}
+
+type mockShutdownFilter struct {
+}
+
+// Invoke adds the requests count and block the new requests if application is closing
+func (gf *mockShutdownFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+ return invoker.Invoke(ctx, invocation)
+}
+
+// OnResponse reduces the number of active processes then return the process result
+func (gf *mockShutdownFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+ return result
+}
diff --git a/protocol/protocolwrapper/protocol_filter_wrapper.go b/protocol/protocolwrapper/protocol_filter_wrapper.go
index a2be0c4..42b6905 100644
--- a/protocol/protocolwrapper/protocol_filter_wrapper.go
+++ b/protocol/protocolwrapper/protocol_filter_wrapper.go
@@ -50,7 +50,7 @@
if pfw.protocol == nil {
pfw.protocol = extension.GetProtocol(invoker.GetURL().Protocol)
}
- invoker = buildInvokerChain(invoker, constant.SERVICE_FILTER_KEY)
+ invoker = BuildInvokerChain(invoker, constant.SERVICE_FILTER_KEY)
return pfw.protocol.Export(invoker)
}
@@ -63,7 +63,7 @@
if invoker == nil {
return nil
}
- return buildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
+ return BuildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
}
// Destroy will destroy all invoker and exporter.
@@ -71,7 +71,7 @@
pfw.protocol.Destroy()
}
-func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
+func BuildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
filterName := invoker.GetURL().GetParam(key, "")
if filterName == "" {
return invoker