blob: af8993f05847c1b72f67c6a4461c90e2d6585f6c [file] [log] [blame]
// Copyright 2022 CeresDB Project Authors. Licensed under Apache-2.0.
package procedure
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"github.com/CeresDB/ceresmeta/server/cluster"
"github.com/pkg/errors"
)
// ShardPicker is used to pick up the shards suitable for scheduling in the cluster.
type ShardPicker interface {
PickShards(ctx context.Context, clusterName string, expectShardNum int) ([]cluster.ShardNodeWithVersion, error)
}
// RandomShardPicker randomly pick up shards that are not on the same node in the current cluster.
type RandomShardPicker struct {
clusterManager cluster.Manager
}
func NewRandomShardPicker(manager cluster.Manager) ShardPicker {
return &RandomShardPicker{
clusterManager: manager,
}
}
// PickShards will pick a specified number of shards as expectShardNum.
func (p *RandomShardPicker) PickShards(ctx context.Context, clusterName string, expectShardNum int) ([]cluster.ShardNodeWithVersion, error) {
getNodeShardResult, err := p.clusterManager.GetNodeShards(ctx, clusterName)
if err != nil {
return []cluster.ShardNodeWithVersion{}, errors.WithMessage(err, "get node shards")
}
nodeShardsMapping := make(map[string][]cluster.ShardNodeWithVersion, 0)
var nodeNames []string
for _, nodeShard := range getNodeShardResult.NodeShards {
_, exists := nodeShardsMapping[nodeShard.ShardNode.NodeName]
if !exists {
nodeShards := []cluster.ShardNodeWithVersion{}
nodeNames = append(nodeNames, nodeShard.ShardNode.NodeName)
nodeShardsMapping[nodeShard.ShardNode.NodeName] = nodeShards
}
nodeShardsMapping[nodeShard.ShardNode.NodeName] = append(nodeShardsMapping[nodeShard.ShardNode.NodeName], nodeShard)
}
if len(nodeShardsMapping) < expectShardNum {
return []cluster.ShardNodeWithVersion{}, errors.WithMessage(ErrNodeNumberNotEnough, fmt.Sprintf("number of nodes is:%d, expecet number of shards is:%d", len(nodeShardsMapping), expectShardNum))
}
var selectNodeNames []string
for i := 0; i < expectShardNum; i++ {
selectNodeIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nodeNames))))
if err != nil {
return []cluster.ShardNodeWithVersion{}, errors.WithMessage(err, "generate random node index")
}
selectNodeNames = append(selectNodeNames, nodeNames[selectNodeIndex.Int64()])
nodeNames[selectNodeIndex.Int64()] = nodeNames[len(nodeNames)-1]
nodeNames = nodeNames[:len(nodeNames)-1]
}
result := []cluster.ShardNodeWithVersion{}
for _, nodeName := range selectNodeNames {
nodeShards := nodeShardsMapping[nodeName]
selectShardIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nodeShards))))
if err != nil {
return []cluster.ShardNodeWithVersion{}, errors.WithMessage(err, "generate random node index")
}
result = append(result, nodeShards[selectShardIndex.Int64()])
}
return result, nil
}