fix: modify the shard picking logic (#259)
## Rationale
We found that with the current shard picking logic, the subtables of a partition table tend to be concentrated on a few nodes, causing hot spots. The picking logic needs to be adjusted to avoid this.
## Detailed Changes
* Modify the shard picking logic: when two shards hold the same number of tables, break the tie by sorting on ShardID (see the sketch below).
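Here is a minimal, self-contained sketch of the new comparator, using simplified stand-in types rather than the real `storage`/`metadata` structs (the exact change is in the diff below). With `sort.SliceStable` alone, shards holding the same number of tables keep whatever order they had in the snapshot; comparing ShardID on ties makes the pick order deterministic.

```go
package main

import (
	"fmt"
	"sort"
)

// Simplified stand-ins for the real metadata types; the actual code works on
// snapshot.Topology.ShardViewsMapping.
type ShardID uint32

type ShardView struct {
	ShardID  ShardID
	TableIDs []uint64
}

// sortShardsByTableCount orders shard IDs so that shards holding fewer tables
// come first; when two shards hold the same number of tables, the smaller
// ShardID wins, which keeps the resulting order deterministic.
func sortShardsByTableCount(shardIDs []ShardID, views map[ShardID]ShardView) {
	sort.SliceStable(shardIDs, func(i, j int) bool {
		v1, v2 := views[shardIDs[i]], views[shardIDs[j]]
		if len(v1.TableIDs) == len(v2.TableIDs) {
			return v1.ShardID < v2.ShardID
		}
		return len(v1.TableIDs) < len(v2.TableIDs)
	})
}

func main() {
	views := map[ShardID]ShardView{
		0: {ShardID: 0, TableIDs: []uint64{10, 11}},
		1: {ShardID: 1, TableIDs: []uint64{12}},
		2: {ShardID: 2, TableIDs: []uint64{13}},
	}
	shardIDs := []ShardID{2, 0, 1}
	sortShardsByTableCount(shardIDs, views)
	fmt.Println(shardIDs) // [1 2 0]: shards 1 and 2 tie on table count, so the smaller ShardID comes first.
}
```

Running the sketch prints `[1 2 0]`: shards 1 and 2 tie on table count, so the smaller ShardID comes first; without the tie-break, their relative order would depend on the input order.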
## Test Plan
Pass all unit tests and integration tests.
diff --git a/server/coordinator/procedure/test/common.go b/server/coordinator/procedure/test/common.go
index 7d35535..5dcbc55 100644
--- a/server/coordinator/procedure/test/common.go
+++ b/server/coordinator/procedure/test/common.go
@@ -15,6 +15,7 @@
"github.com/CeresDB/ceresmeta/server/cluster/metadata"
"github.com/CeresDB/ceresmeta/server/coordinator/eventdispatch"
"github.com/CeresDB/ceresmeta/server/coordinator/procedure"
+ "github.com/CeresDB/ceresmeta/server/coordinator/scheduler/nodepicker"
"github.com/CeresDB/ceresmeta/server/etcdutil"
"github.com/CeresDB/ceresmeta/server/storage"
"github.com/stretchr/testify/require"
@@ -135,6 +136,51 @@
return c
}
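+// InitEmptyClusterWithConfig returns a cluster with the given shard total and number of registered nodes; no shards are assigned to nodes yet.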
+func InitEmptyClusterWithConfig(ctx context.Context, t *testing.T, shardNumber int, nodeNumber int) *cluster.Cluster {
+ re := require.New(t)
+
+ _, client, _ := etcdutil.PrepareEtcdServerAndClient(t)
+ clusterStorage := storage.NewStorageWithEtcdBackend(client, TestRootPath, storage.Options{
+ MaxScanLimit: 100, MinScanLimit: 10, MaxOpsPerTxn: 32,
+ })
+
+ logger := zap.NewNop()
+
+ clusterMetadata := metadata.NewClusterMetadata(logger, storage.Cluster{
+ ID: 0,
+ Name: ClusterName,
+ MinNodeCount: uint32(nodeNumber),
+ ShardTotal: uint32(shardNumber),
+ EnableSchedule: DefaultSchedulerOperator,
+ TopologyType: DefaultTopologyType,
+ ProcedureExecutingBatchSize: DefaultProcedureExecutingBatchSize,
+ CreatedAt: 0,
+ }, clusterStorage, client, TestRootPath, DefaultIDAllocatorStep)
+
+ err := clusterMetadata.Init(ctx)
+ re.NoError(err)
+
+ err = clusterMetadata.Load(ctx)
+ re.NoError(err)
+
+ c, err := cluster.NewCluster(logger, clusterMetadata, client, TestRootPath)
+ re.NoError(err)
+
+ _, _, err = c.GetMetadata().GetOrCreateSchema(ctx, TestSchemaName)
+ re.NoError(err)
+
+ lastTouchTime := time.Now().UnixMilli()
+ for i := 0; i < nodeNumber; i++ {
+ err = c.GetMetadata().RegisterNode(ctx, metadata.RegisteredNode{
+ Node: storage.Node{Name: fmt.Sprintf("node%d", i), LastTouchTime: uint64(lastTouchTime)},
+ ShardInfos: nil,
+ })
+ re.NoError(err)
+ }
+
+ return c
+}
+
// InitPrepareCluster will return a cluster that has created shards and nodes, and cluster state is prepare.
func InitPrepareCluster(ctx context.Context, t *testing.T) *cluster.Cluster {
re := require.New(t)
@@ -167,3 +213,30 @@
return c
}
+
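+// InitStableClusterWithConfig returns a cluster with the given numbers of nodes and shards, with every shard assigned to a node and the cluster state set to stable.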
+func InitStableClusterWithConfig(ctx context.Context, t *testing.T, nodeNumber int, shardNumber int) *cluster.Cluster {
+ re := require.New(t)
+ c := InitEmptyClusterWithConfig(ctx, t, shardNumber, nodeNumber)
+ snapshot := c.GetMetadata().GetClusterSnapshot()
+ shardNodes := make([]storage.ShardNode, 0, DefaultShardTotal)
+ nodePicker := nodepicker.NewConsistentUniformHashNodePicker(zap.NewNop())
+ var unAssignedShardIDs []storage.ShardID
+ for i := 0; i < shardNumber; i++ {
+ unAssignedShardIDs = append(unAssignedShardIDs, storage.ShardID(i))
+ }
+ shardNodeMapping, err := nodePicker.PickNode(ctx, nodepicker.Config{NumTotalShards: uint32(shardNumber)}, unAssignedShardIDs, snapshot.RegisteredNodes)
+ re.NoError(err)
+
+ for shardID, node := range shardNodeMapping {
+ shardNodes = append(shardNodes, storage.ShardNode{
+ ID: shardID,
+ ShardRole: storage.ShardRoleLeader,
+ NodeName: node.Node.Name,
+ })
+ }
+
+ err = c.GetMetadata().UpdateClusterView(ctx, storage.ClusterStateStable, shardNodes)
+ re.NoError(err)
+
+ return c
+}
diff --git a/server/coordinator/shard_picker.go b/server/coordinator/shard_picker.go
index 1bc61ef..353c847 100644
--- a/server/coordinator/shard_picker.go
+++ b/server/coordinator/shard_picker.go
@@ -40,10 +40,16 @@
sortedShardsByTableCount = append(sortedShardsByTableCount, shardNode.ID)
}
- // sort shard by table number,
+ // Sort shard by table number,
// the shard with the smallest number of tables is at the front of the array.
sort.SliceStable(sortedShardsByTableCount, func(i, j int) bool {
- return len(snapshot.Topology.ShardViewsMapping[sortedShardsByTableCount[i]].TableIDs) < len(snapshot.Topology.ShardViewsMapping[sortedShardsByTableCount[j]].TableIDs)
+ shardView1 := snapshot.Topology.ShardViewsMapping[sortedShardsByTableCount[i]]
+ shardView2 := snapshot.Topology.ShardViewsMapping[sortedShardsByTableCount[j]]
+ // When the number of tables is the same, break the tie by comparing ShardID.
+ if len(shardView1.TableIDs) == len(shardView2.TableIDs) {
+ return shardView1.ShardID < shardView2.ShardID
+ }
+ return len(shardView1.TableIDs) < len(shardView2.TableIDs)
})
result := make([]storage.ShardNode, 0, expectShardNum)
diff --git a/server/coordinator/shard_picker_test.go b/server/coordinator/shard_picker_test.go
index 4c21eab..52ad84b 100644
--- a/server/coordinator/shard_picker_test.go
+++ b/server/coordinator/shard_picker_test.go
@@ -4,6 +4,7 @@
import (
"context"
+ "sort"
"testing"
"github.com/CeresDB/ceresmeta/server/cluster/metadata"
@@ -71,4 +72,34 @@
for _, shardNode := range shardNodes {
re.NotEqual(shardNode.ID, 1)
}
+
+ checkPartitionTable(ctx, shardPicker, t, 50, 256, 20, 2)
+ checkPartitionTable(ctx, shardPicker, t, 50, 256, 30, 2)
+ checkPartitionTable(ctx, shardPicker, t, 50, 256, 40, 2)
+ checkPartitionTable(ctx, shardPicker, t, 50, 256, 50, 2)
+}
+
+func checkPartitionTable(ctx context.Context, shardPicker coordinator.ShardPicker, t *testing.T, nodeNumber int, shardNumber int, subTableNumber int, maxDifference int) {
+ re := require.New(t)
+
+ var shardNodes []storage.ShardNode
+
+ c := test.InitStableClusterWithConfig(ctx, t, nodeNumber, shardNumber)
+ shardNodes, err := shardPicker.PickShards(ctx, c.GetMetadata().GetClusterSnapshot(), subTableNumber)
+ re.NoError(err)
+
+ nodeTableCountMapping := make(map[string]int, 0)
+ for _, shardNode := range shardNodes {
+ nodeTableCountMapping[shardNode.NodeName]++
+ }
+
+ // Ensure the per-node table counts differ by no more than maxDifference.
+ var nodeTableNumberSlice []int
+ for _, tableNumber := range nodeTableCountMapping {
+ nodeTableNumberSlice = append(nodeTableNumberSlice, tableNumber)
+ }
+ sort.Ints(nodeTableNumberSlice)
+ minTableNumber := nodeTableNumberSlice[0]
+ maxTableNumber := nodeTableNumberSlice[len(nodeTableNumberSlice)-1]
+ re.LessOrEqual(maxTableNumber-minTableNumber, maxDifference)
}