blob: 9056403d5a0870ddfb1ae9f4f019bd2a773a3dee [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version
* 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package org.apache.storm.grouping;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.storm.Config;
import org.apache.storm.generated.GlobalStreamId;
import org.apache.storm.generated.NodeInfo;
import org.apache.storm.networktopography.DNSToSwitchMapping;
import org.apache.storm.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.storm.shade.com.google.common.collect.Sets;
import org.apache.storm.task.WorkerTopologyContext;
import org.apache.storm.utils.ObjectReader;
import org.apache.storm.utils.ReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping, Serializable {
// Upper limit for a target's weight; also the value every weight starts at and is reset to.
private static final int MAX_WEIGHT = 100;
private static final Logger LOG = LoggerFactory.getLogger(LoadAwareShuffleGrouping.class);
// Per-target bookkeeping keyed by task id: the task's index into rets and its current weight.
private final Map<Integer, IndexAndWeights> orig = new HashMap<>();
// One single-element list per target task, built once in prepare() and returned by chooseTasks().
@VisibleForTesting
List<Integer>[] rets;
// The active ring of indices into rets that chooseTasks() walks; replaced as a whole by updateRing().
@VisibleForTesting
volatile int[] choices;
// Length of both choices and prepareChoices.
private int capacity;
// Source of randomness for padding and shuffling the ring.
private Random random;
// Scratch ring that updateRing() fills and then swaps with choices.
private volatile int[] prepareChoices;
// Cursor into choices; updateRing() resets it to -1 so the next incrementAndGet() yields 0.
private AtomicInteger current;
// Locality scope whose targets are currently eligible to receive tuples.
private Scope currentScope;
// Host and port of the worker this grouping runs in.
private NodeInfo sourceNodeInfo;
// Task ids handed to prepare().
private List<Integer> targetTasks;
// Task-to-location map obtained from the topology context in prepare(); read on every ring update.
private AtomicReference<Map<Integer, NodeInfo>> taskToNodePort;
// Topology configuration obtained from the context in prepare().
private Map<String, Object> conf;
// Plugin used to resolve host names to racks.
private DNSToSwitchMapping dnsToSwitchMapping;
// Target tasks grouped by their locality relative to this worker; rebuilt by refreshLocalityGroup().
private Map<Scope, List<Integer>> localityGroup;
// Average-load threshold above which transition() moves to the next wider scope.
private double higherBound;
// Average-load threshold below which transition() moves to the next narrower scope.
private double lowerBound;
/**
 * Initializes all state for this grouping and builds the first ring of choices.
 *
 * @param context     supplies this worker's host and port, the task-to-location map and the topology conf
 * @param stream      not used by this implementation
 * @param targetTasks ids of the tasks tuples may be sent to
 */
@Override
public void prepare(WorkerTopologyContext context, GlobalStreamId stream, List<Integer> targetTasks) {
    random = new Random();
    sourceNodeInfo = new NodeInfo(context.getThisWorkerHost(), Sets.newHashSet((long) context.getThisWorkerPort()));
    taskToNodePort = context.getTaskToNodePort();
    this.targetTasks = targetTasks;

    // A single target needs a ring of one slot; otherwise use at least 1000 slots, or 5 per target.
    final int targetCount = targetTasks.size();
    if (targetCount == 1) {
        capacity = 1;
    } else {
        capacity = Math.max(1000, targetCount * 5);
    }

    conf = context.getConf();
    dnsToSwitchMapping = ReflectionUtils.newInstance((String) conf.get(Config.STORM_NETWORK_TOPOGRAPHY_PLUGIN));
    localityGroup = new HashMap<>();
    currentScope = Scope.WORKER_LOCAL;
    higherBound = ObjectReader.getDouble(conf.get(Config.TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND));
    lowerBound = ObjectReader.getDouble(conf.get(Config.TOPOLOGY_LOCALITYAWARE_LOWER_BOUND));

    // Pre-build the single-element result list for every target and remember its slot.
    rets = (List<Integer>[]) new List<?>[targetCount];
    int slot = 0;
    for (int target : targetTasks) {
        rets[slot] = Arrays.asList(target);
        orig.put(target, new IndexAndWeights(slot));
        slot++;
    }

    // Both rings must exist before the first update: one is served while the other is rebuilt.
    choices = new int[capacity];
    current = new AtomicInteger(0);
    prepareChoices = new int[capacity];

    // No load information is available yet, so build the initial ring without it.
    updateRing(null);
}
/**
 * Returns the next target from the ring, wrapping to the start when the end is reached.
 *
 * @param taskId not used by this implementation
 * @param values not used by this implementation
 * @return the pre-built single-element list holding the chosen task id
 */
@Override
public List<Integer> chooseTasks(int taskId, List<Object> values) {
    for (;;) {
        final int position = current.incrementAndGet();
        if (position == capacity) {
            // This caller hit the end of the ring: rewind the cursor and serve slot 0.
            current.set(0);
            return rets[choices[0]];
        }
        if (position < capacity) {
            return rets[choices[position]];
        }
        // Another thread reached the end first and has not rewound the cursor yet; retry.
    }
}
/**
 * Rebuilds the ring of choices using the supplied load information.
 *
 * @param loadMapping load per task, passed straight through to updateRing
 */
@Override
public void refreshLoad(LoadMapping loadMapping) {
updateRing(loadMapping);
}
/**
 * Recomputes which locality scope every target task falls into, using the latest
 * task-to-location map. The existing per-scope lists are emptied and reused.
 */
private void refreshLocalityGroup() {
    final Map<Integer, NodeInfo> taskLocations = taskToNodePort.get();
    final Map<String, String> rackByHost = getHostToRackMapping(taskLocations);

    for (List<Integer> members : localityGroup.values()) {
        members.clear();
    }

    for (int target : targetTasks) {
        final Scope scope = calculateScope(taskLocations, rackByHost, target);
        localityGroup.computeIfAbsent(scope, unused -> new ArrayList<>()).add(target);
    }
}
/**
 * Collects the targets reachable within the given scope: the tasks grouped under that scope
 * followed by those of every narrower scope, in narrowing order.
 *
 * @param scope the widest scope to include
 * @return a new list of task ids; empty when no task falls in the scope or any narrower one
 */
private List<Integer> getTargetsInScope(Scope scope) {
    final List<Integer> collected = new ArrayList<>();
    Scope cursor = scope;
    while (true) {
        final List<Integer> members = localityGroup.get(cursor);
        if (members != null) {
            collected.addAll(members);
        }
        final Scope narrower = Scope.downgrade(cursor);
        if (narrower == cursor) {
            // downgrade() returns its argument once the narrowest scope is reached.
            return collected;
        }
        cursor = narrower;
    }
}
/**
 * Decides which scope should be used next.
 *
 * <p>While the current scope has no targets, {@code currentScope} itself is widened one step at a
 * time. Once targets exist, the average load over them is compared with the configured bounds:
 * below the lower bound the scope is narrowed (only if the narrower scope has targets), above the
 * higher bound it is widened, otherwise it stays unchanged.
 *
 * @param load current load per task, or {@code null} when no load information is available
 * @return the scope to use next
 * @throws RuntimeException if even the widest scope contains no target tasks
 */
private Scope transition(LoadMapping load) {
    List<Integer> candidates = getTargetsInScope(currentScope);
    while (candidates.isEmpty()) {
        final Scope wider = Scope.upgrade(currentScope);
        if (wider == currentScope) {
            throw new RuntimeException("The current scope " + currentScope + " has no target tasks.");
        }
        currentScope = wider;
        candidates = getTargetsInScope(currentScope);
    }

    if (load == null) {
        return currentScope;
    }

    final double meanLoad = candidates.stream().mapToDouble(task -> load.get(task)).average().getAsDouble();

    if (meanLoad < lowerBound) {
        // Lightly loaded: try to keep traffic closer, unless nothing lives in the narrower scope.
        final Scope narrower = Scope.downgrade(currentScope);
        return getTargetsInScope(narrower).isEmpty() ? currentScope : narrower;
    }
    if (meanLoad > higherBound) {
        // Heavily loaded: spread traffic over a wider set of targets.
        return Scope.upgrade(currentScope);
    }
    return currentScope;
}
/**
 * Rebuilds the ring of choices from the current locality grouping and load, then publishes it.
 *
 * <p>Steps: refresh locality groups, pick the scope, adjust per-target weights, fill the scratch
 * ring proportionally to the weights, shuffle it, swap it with the active ring and reset the cursor.
 *
 * @param load current load per task, or {@code null} to treat every target as unloaded
 */
private synchronized void updateRing(LoadMapping load) {
    refreshLocalityGroup();

    final Scope scopeBefore = currentScope;
    currentScope = transition(load);
    if (scopeBefore != currentScope) {
        // The set of eligible targets changed, so start every weight from the maximum again.
        for (IndexAndWeights entry : orig.values()) {
            entry.resetWeight();
        }
    }

    final List<Integer> targetsInScope = getTargetsInScope(currentScope);
    final boolean haveLoad = load != null;

    // Weights are adjusted relative to the least-loaded target in scope.
    final double minLoad = haveLoad
        ? targetsInScope.stream().mapToDouble(task -> load.get(task)).min().getAsDouble()
        : 0;
    // A target whose load is at most 0.05 above the minimum is treated as not congested.
    final double uncongestedLimit = minLoad + (0.05);

    for (int target : targetsInScope) {
        final IndexAndWeights entry = orig.get(target);
        final double targetLoad = haveLoad ? load.get(target) : 0.0;
        if (targetLoad <= uncongestedLimit) {
            // Not congested: grow slowly.
            entry.weight = Math.min(MAX_WEIGHT, entry.weight + 1);
        } else {
            // Congested: shrink much faster than we grow.
            entry.weight = Math.max(0, entry.weight - 10);
        }
    }

    long weightSum = 0;
    for (int target : targetsInScope) {
        weightSum += orig.get(target).weight;
    }

    // Give every target a share of the ring proportional to its weight.
    int filled = 0;
    if (weightSum > 0) {
        for (int target : targetsInScope) {
            final IndexAndWeights entry = orig.get(target);
            final int share = (int) ((entry.weight / (double) weightSum) * capacity);
            for (int n = 0; n < share && filled < capacity; n++) {
                prepareChoices[filled] = entry.index;
                filled++;
            }
        }
        if (filled > 0) {
            // Truncation can leave slots unfilled; pad them with random picks from what is already there.
            while (filled < capacity) {
                prepareChoices[filled] = prepareChoices[random.nextInt(filled)];
                filled++;
            }
        }
    }
    if (filled == 0) {
        // Should not happen, since the least-loaded target always gains weight; as a safeguard
        // fall back to cycling through every target.
        for (int slot = 0; slot < capacity; slot++) {
            prepareChoices[slot] = slot % rets.length;
        }
    }

    shuffleArray(prepareChoices);

    // Publish the new ring and keep the old one as the next scratch buffer.
    final int[] retired = choices;
    choices = prepareChoices;
    prepareChoices = retired;
    current.set(-1);
}
/**
 * Shuffles the array in place: walking from the last position down to the second, each position
 * is swapped with a randomly chosen position at or before it.
 *
 * @param arr the array to shuffle
 */
private void shuffleArray(int[] arr) {
    int remaining = arr.length;
    while (remaining > 1) {
        final int pick = random.nextInt(remaining);
        swap(arr, remaining - 1, pick);
        remaining--;
    }
}
/**
 * Exchanges the elements at positions {@code i} and {@code j} of {@code arr}.
 *
 * @param arr the array to modify
 * @param i   first position
 * @param j   second position
 */
private void swap(int[] arr, int i, int j) {
    final int first = arr[i];
    final int second = arr[j];
    arr[i] = second;
    arr[j] = first;
}
/**
 * Classifies a target task by how close it is to this worker.
 *
 * @param taskToNodePort location of each task
 * @param hostToRack     rack of each host
 * @param target         id of the task to classify
 * @return {@code WORKER_LOCAL} for same host and port, {@code HOST_LOCAL} for same host,
 *         {@code RACK_LOCAL} for same rack, and {@code EVERYTHING} when the task's location or
 *         either rack is unknown or the racks differ
 */
private Scope calculateScope(Map<Integer, NodeInfo> taskToNodePort, Map<String, String> hostToRack, int target) {
    final NodeInfo targetNodeInfo = taskToNodePort.get(target);
    if (targetNodeInfo == null) {
        return Scope.EVERYTHING;
    }

    final String sourceHost = sourceNodeInfo.get_node();
    final String targetHost = targetNodeInfo.get_node();
    final String sourceRack = hostToRack.get(sourceHost);
    final String targetRack = hostToRack.get(targetHost);

    if (sourceRack == null || targetRack == null || !sourceRack.equals(targetRack)) {
        return Scope.EVERYTHING;
    }
    if (!sourceHost.equals(targetHost)) {
        return Scope.RACK_LOCAL;
    }
    return sourceNodeInfo.get_port().equals(targetNodeInfo.get_port()) ? Scope.WORKER_LOCAL : Scope.HOST_LOCAL;
}
/**
 * Resolves the rack of every host involved in this grouping: the hosts of all target tasks whose
 * location is known, plus this worker's own host.
 *
 * @param taskToNodePort location of each task
 * @return whatever the configured {@code DNSToSwitchMapping} returns for the collected hosts
 */
private Map<String, String> getHostToRackMapping(Map<Integer, NodeInfo> taskToNodePort) {
    Set<String> hosts = new HashSet<>();
    for (int task : targetTasks) {
        // If the worker hosting this task is about to be killed by an assignment sync, the map
        // refreshed by WorkerState may be empty, so a missing entry must be tolerated.
        // A single lookup also avoids dereferencing a null value behind a present key.
        NodeInfo nodeInfo = taskToNodePort.get(task);
        if (nodeInfo != null) {
            hosts.add(nodeInfo.get_node());
        } else {
            LOG.error("Could not find NodeInfo for task {} in local cache.", task);
        }
    }
    hosts.add(sourceNodeInfo.get_node());
    return dnsToSwitchMapping.resolve(new ArrayList<>(hosts));
}
/**
 * Returns the number of slots in the ring of choices. Intended for tests only.
 *
 * @return the ring capacity computed in prepare
 */
public int getCapacity() {
return capacity;
}
/**
 * Locality scopes, declared from narrowest to widest.
 */
enum Scope {
    WORKER_LOCAL, HOST_LOCAL, RACK_LOCAL, EVERYTHING;

    /**
     * Returns the next narrower scope, or {@code WORKER_LOCAL} when there is none.
     *
     * @param current the scope to narrow
     * @return the scope one step narrower than {@code current}
     */
    public static Scope downgrade(Scope current) {
        final Scope narrower;
        switch (current) {
            case EVERYTHING:
                narrower = RACK_LOCAL;
                break;
            case RACK_LOCAL:
                narrower = HOST_LOCAL;
                break;
            default:
                // HOST_LOCAL and WORKER_LOCAL both end up at the narrowest scope.
                narrower = WORKER_LOCAL;
                break;
        }
        return narrower;
    }

    /**
     * Returns the next wider scope, or {@code EVERYTHING} when there is none.
     *
     * @param current the scope to widen
     * @return the scope one step wider than {@code current}
     */
    public static Scope upgrade(Scope current) {
        final Scope wider;
        switch (current) {
            case WORKER_LOCAL:
                wider = HOST_LOCAL;
                break;
            case HOST_LOCAL:
                wider = RACK_LOCAL;
                break;
            default:
                // RACK_LOCAL and EVERYTHING both end up at the widest scope.
                wider = EVERYTHING;
                break;
        }
        return wider;
    }
}
/**
 * Bookkeeping for one target task: its fixed slot in {@code rets} and its mutable weight.
 */
private static class IndexAndWeights {
    /** Position of the task's result list in {@code rets}. */
    final int index;
    /** Current weight; starts at the maximum. */
    int weight = MAX_WEIGHT;

    IndexAndWeights(int index) {
        this.index = index;
    }

    /** Puts the weight back to the maximum. */
    void resetWeight() {
        weight = MAX_WEIGHT;
    }
}
}