// blob: 3a3dc1fe22dbe0f65c55ee44d7d6517c7a0407fe [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.twitter.distributedlog.service.balancer;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.RateLimiter;
import com.twitter.distributedlog.client.monitor.MonitorServiceClient;
import com.twitter.distributedlog.service.ClientUtils;
import com.twitter.distributedlog.service.DLSocketAddress;
import com.twitter.distributedlog.service.DistributedLogClient;
import com.twitter.distributedlog.service.DistributedLogClientBuilder;
import com.twitter.util.Await;
import com.twitter.util.Function;
import com.twitter.util.Future;
import com.twitter.util.FutureEventListener;
import java.io.Serializable;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A balancer balances ownerships with a cluster of targets.
 *
 * <p>The balancer reads the stream ownership distribution from the monitor client,
 * orders the hosts by the number of streams they own (most loaded first) and moves
 * streams from the more loaded hosts to the less loaded ones. A stream is moved by
 * releasing it on the host that owns it and then issuing a {@code check} for it on
 * the target host.
 */
public class ClusterBalancer implements Balancer {

    private static final Logger logger = LoggerFactory.getLogger(ClusterBalancer.class);

    /**
     * Represent a single host together with the streams it owns.
     *
     * <p>The clients used to talk to the host are built lazily, on the first call to
     * {@link #getClient()} or {@link #getMonitor()}.
     */
    static class Host {
        final SocketAddress address;
        final Set<String> streams;
        final DistributedLogClientBuilder clientBuilder;
        DistributedLogClient client = null;
        MonitorServiceClient monitor = null;

        /**
         * @param address
         *          address of the host
         * @param streams
         *          streams owned by the host; the balancer updates this set in place
         *          as streams are moved
         * @param clientBuilder
         *          builder used as a template to build the clients to this host
         */
        Host(SocketAddress address, Set<String> streams,
             DistributedLogClientBuilder clientBuilder) {
            this.address = address;
            this.streams = streams;
            this.clientBuilder = clientBuilder;
        }

        /**
         * Build the client pair for this host if it has not been built yet.
         * Callers are expected to hold the monitor of this object.
         */
        private void initializeClientsIfNeeded() {
            if (null == client) {
                Pair<DistributedLogClient, MonitorServiceClient> clientPair =
                        createDistributedLogClient(address, clientBuilder);
                client = clientPair.getLeft();
                monitor = clientPair.getRight();
            }
        }

        /**
         * @return the distributedlog client to this host, building it on first use
         */
        synchronized DistributedLogClient getClient() {
            initializeClientsIfNeeded();
            return client;
        }

        /**
         * @return the monitor client to this host, building it on first use
         */
        synchronized MonitorServiceClient getMonitor() {
            initializeClientsIfNeeded();
            return monitor;
        }

        /**
         * Close the distributedlog client to this host if one was built.
         * Nothing happens when no client was ever requested.
         */
        synchronized void close() {
            if (null != client) {
                client.close();
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Host(").append(address).append(")");
            return sb.toString();
        }
    }

    /**
     * Order hosts by the number of streams they own, in descending order.
     */
    static class HostComparator implements Comparator<Host>, Serializable {
        private static final long serialVersionUID = 7984973796525102538L;

        @Override
        public int compare(Host h1, Host h2) {
            // Both sizes are non-negative, so the subtraction cannot overflow.
            return h2.streams.size() - h1.streams.size();
        }
    }

    protected final DistributedLogClientBuilder clientBuilder;
    protected final DistributedLogClient client;
    protected final MonitorServiceClient monitor;

    /**
     * Create a balancer whose cluster-wide clients are built from <i>clientBuilder</i>.
     *
     * @param clientBuilder
     *          builder for the cluster clients; also used as the template for the
     *          per-host clients
     */
    public ClusterBalancer(DistributedLogClientBuilder clientBuilder) {
        this(clientBuilder, ClientUtils.buildClient(clientBuilder));
    }

    /**
     * Create a balancer using an already built client pair.
     *
     * @param clientBuilder
     *          builder used as the template for the per-host clients
     * @param clientPair
     *          distributedlog client (left) and monitor client (right) of the cluster
     */
    ClusterBalancer(DistributedLogClientBuilder clientBuilder,
                    Pair<DistributedLogClient, MonitorServiceClient> clientPair) {
        this.clientBuilder = clientBuilder;
        this.client = clientPair.getLeft();
        this.monitor = clientPair.getRight();
    }

    /**
     * Build a new distributedlog client to a single host <i>host</i>.
     *
     * @param host
     *          host to access
     * @param clientBuilder
     *          builder to copy the client settings from; it is not modified
     * @return distributedlog clients
     */
    static Pair<DistributedLogClient, MonitorServiceClient> createDistributedLogClient(
            SocketAddress host, DistributedLogClientBuilder clientBuilder) {
        DistributedLogClientBuilder newBuilder =
                DistributedLogClientBuilder.newBuilder(clientBuilder).host(host);
        return ClientUtils.buildClient(newBuilder);
    }

    /**
     * Move all the streams away from <i>source</i>. Equivalent to calling
     * {@link #balance(int, double, int, Optional, Optional)} with a water mark of 0,
     * a tolerance of 0 and <i>source</i> as the source host.
     */
    @Override
    public void balanceAll(String source,
                           int rebalanceConcurrency, /* unused */
                           Optional<RateLimiter> rebalanceRateLimiter) {
        balance(0, 0.0f, rebalanceConcurrency, Optional.of(source), rebalanceRateLimiter);
    }

    /**
     * Balance the streams across all the hosts of the cluster, without a dedicated
     * source host.
     */
    @Override
    public void balance(int rebalanceWaterMark,
                        double rebalanceTolerancePercentage,
                        int rebalanceConcurrency, /* unused */
                        Optional<RateLimiter> rebalanceRateLimiter) {
        Optional<String> source = Optional.absent();
        balance(rebalanceWaterMark, rebalanceTolerancePercentage, rebalanceConcurrency, source, rebalanceRateLimiter);
    }

    /**
     * Balance the streams of the cluster.
     *
     * <p>Nothing is done when the ownership distribution reported by the monitor has
     * at most one host, or when <i>source</i> is present but is not part of the
     * distribution. The per-host clients created while moving streams are closed
     * before this method returns.
     *
     * @param rebalanceWaterMark
     *          lower bound on the number of streams left on a host that streams are
     *          moved from
     * @param rebalanceTolerancePercentage
     *          percentage above the average load that a target host may reach
     * @param rebalanceConcurrency
     *          unused
     * @param source
     *          when present, the only host to move streams from; all of its streams
     *          are moved away and it is excluded when computing the average load
     * @param rebalanceRateLimiter
     *          when present, one permit is acquired for each stream move
     */
    public void balance(int rebalanceWaterMark,
                        double rebalanceTolerancePercentage,
                        int rebalanceConcurrency,
                        Optional<String> source,
                        Optional<RateLimiter> rebalanceRateLimiter) {
        Map<SocketAddress, Set<String>> distribution = monitor.getStreamOwnershipDistribution();
        if (distribution.size() <= 1) {
            return;
        }
        SocketAddress sourceAddr = null;
        if (source.isPresent()) {
            sourceAddr = DLSocketAddress.parseSocketAddress(source.get());
            logger.info("Balancer source is {}", sourceAddr);
            if (!distribution.containsKey(sourceAddr)) {
                return;
            }
        }
        // Get the list of hosts ordered by number of streams in DESC order
        List<Host> hosts = new ArrayList<Host>(distribution.size());
        for (Map.Entry<SocketAddress, Set<String>> entry : distribution.entrySet()) {
            Host host = new Host(entry.getKey(), entry.getValue(), clientBuilder);
            hosts.add(host);
        }
        Collections.sort(hosts, new HostComparator());
        try {
            // find the host to move streams from.
            int hostIdxMoveFrom = -1;
            if (null != sourceAddr) {
                for (Host host : hosts) {
                    ++hostIdxMoveFrom;
                    if (sourceAddr.equals(host.address)) {
                        break;
                    }
                }
            }
            // compute the average load.
            int totalStream = 0;
            for (Host host : hosts) {
                totalStream += host.streams.size();
            }
            double averageLoad;
            if (hostIdxMoveFrom >= 0) {
                // the source host is drained, so its streams are shared by the other hosts.
                averageLoad = ((double) totalStream / (hosts.size() - 1));
            } else {
                averageLoad = ((double) totalStream / hosts.size());
            }
            int moveFromLowWaterMark;
            int moveToHighWaterMark =
                    Math.max(1, (int) (averageLoad + averageLoad * rebalanceTolerancePercentage / 100.0f));
            if (hostIdxMoveFrom >= 0) {
                moveFromLowWaterMark = Math.max(0, rebalanceWaterMark);
                moveStreams(
                        hosts,
                        new AtomicInteger(hostIdxMoveFrom), moveFromLowWaterMark,
                        new AtomicInteger(hosts.size() - 1), moveToHighWaterMark,
                        rebalanceRateLimiter);
                moveRemainingStreamsFromSource(hosts.get(hostIdxMoveFrom), hosts, rebalanceRateLimiter);
            } else {
                moveFromLowWaterMark = Math.max((int) Math.ceil(averageLoad), rebalanceWaterMark);
                AtomicInteger moveFrom = new AtomicInteger(0);
                AtomicInteger moveTo = new AtomicInteger(hosts.size() - 1);
                while (moveFrom.get() < moveTo.get()) {
                    moveStreams(hosts, moveFrom, moveFromLowWaterMark,
                            moveTo, moveToHighWaterMark, rebalanceRateLimiter);
                    moveFrom.incrementAndGet();
                }
            }
        } finally {
            closeHosts(hosts);
        }
    }

    /**
     * Close the clients of every host. A failure to close one host is logged and does
     * not prevent the remaining hosts from being closed.
     */
    private static void closeHosts(List<Host> hosts) {
        for (Host host : hosts) {
            try {
                host.close();
            } catch (RuntimeException e) {
                logger.warn("Failed to close clients to host " + host, e);
            }
        }
    }

    /**
     * Move streams from the host at index <i>hostIdxMoveFrom</i> until it owns no more
     * than <i>moveFromLowWaterMark</i> streams.
     *
     * <p>Targets are taken starting at index <i>hostIdxMoveTo</i>; whenever the current
     * target owns at least <i>moveToHighWaterMark</i> streams, <i>hostIdxMoveTo</i> is
     * decremented to try the previous host in the list. The method returns once the
     * target index reaches the source index, once enough streams have been tried, or
     * once the calling thread is interrupted. A stream whose move fails is not retried.
     *
     * @param hosts
     *          hosts of the cluster; {@link #balance(int, double, int, Optional, Optional)}
     *          passes them ordered by number of streams in DESC order
     * @param hostIdxMoveFrom
     *          index of the host to move streams from
     * @param moveFromLowWaterMark
     *          number of streams allowed to stay on the source host
     * @param hostIdxMoveTo
     *          index of the first host to move streams to; updated in place
     * @param moveToHighWaterMark
     *          number of streams at which a host stops being used as a target
     * @param rateLimiter
     *          when present, one permit is acquired before each stream move
     */
    void moveStreams(List<Host> hosts,
                     AtomicInteger hostIdxMoveFrom,
                     int moveFromLowWaterMark,
                     AtomicInteger hostIdxMoveTo,
                     int moveToHighWaterMark,
                     Optional<RateLimiter> rateLimiter) {
        if (hostIdxMoveFrom.get() < 0 || hostIdxMoveFrom.get() >= hosts.size()
                || hostIdxMoveTo.get() < 0 || hostIdxMoveTo.get() >= hosts.size()
                || hostIdxMoveFrom.get() >= hostIdxMoveTo.get()) {
            return;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Moving streams : hosts = {}, from = {}, to = {} :"
                    + " from_low_water_mark = {}, to_high_water_mark = {}",
                    new Object[] {
                        hosts,
                        hostIdxMoveFrom.get(),
                        hostIdxMoveTo.get(),
                        moveFromLowWaterMark,
                        moveToHighWaterMark });
        }
        Host hostMoveFrom = hosts.get(hostIdxMoveFrom.get());
        int numStreamsOnFromHost = hostMoveFrom.streams.size();
        if (numStreamsOnFromHost <= moveFromLowWaterMark) {
            // do nothing
            return;
        }
        int numStreamsToMove = numStreamsOnFromHost - moveFromLowWaterMark;
        // work on a shuffled copy, since the host's own set is updated while moving.
        LinkedList<String> streamsToMove = new LinkedList<String>(hostMoveFrom.streams);
        Collections.shuffle(streamsToMove);
        if (logger.isDebugEnabled()) {
            logger.debug("Try to move {} streams from host {} : streams = {}",
                    new Object[] { numStreamsToMove, hostMoveFrom.address, streamsToMove });
        }
        while (numStreamsToMove-- > 0 && !streamsToMove.isEmpty()
                && !Thread.currentThread().isInterrupted()) {
            // pick a host to move
            Host hostMoveTo = hosts.get(hostIdxMoveTo.get());
            while (hostMoveTo.streams.size() >= moveToHighWaterMark) {
                int hostIdx = hostIdxMoveTo.decrementAndGet();
                logger.info("move to host : {}, from {}", hostIdx, hostIdxMoveFrom.get());
                if (hostIdx <= hostIdxMoveFrom.get()) {
                    return;
                } else {
                    hostMoveTo = hosts.get(hostIdx);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Target host to move moved to host {} @ {}",
                                hostIdx, hostMoveTo);
                    }
                }
            }
            // acquire the permit only once a target is known, so that no permit is
            // consumed when there is no host left to move streams to.
            if (rateLimiter.isPresent()) {
                rateLimiter.get().acquire();
            }
            // pick a stream
            String stream = streamsToMove.remove();
            // move the stream
            if (moveStream(stream, hostMoveFrom, hostMoveTo)) {
                hostMoveFrom.streams.remove(stream);
                hostMoveTo.streams.add(stream);
            }
        }
    }

    /**
     * Move every stream still owned by <i>source</i> to the other hosts, assigning the
     * streams round-robin starting from the last host of <i>hosts</i>.
     *
     * <p>Each stream is tried once; a stream whose move fails stays recorded on the
     * source. Nothing is done when <i>hosts</i> has no host other than <i>source</i>.
     * The method stops early when the calling thread is interrupted.
     *
     * @param source
     *          host to move the streams from
     * @param hosts
     *          hosts of the cluster, possibly including <i>source</i>
     * @param rateLimiter
     *          when present, one permit is acquired before each stream move
     */
    void moveRemainingStreamsFromSource(Host source,
                                        List<Host> hosts,
                                        Optional<RateLimiter> rateLimiter) {
        LinkedList<String> streamsToMove = new LinkedList<String>(source.streams);
        Collections.shuffle(streamsToMove);
        if (logger.isDebugEnabled()) {
            logger.debug("Try to move remaining streams from {} : {}", source, streamsToMove);
        }
        // Without a host other than the source the loop below would never consume a
        // stream (or would index an empty list), so bail out up front.
        boolean hasTarget = false;
        for (Host host : hosts) {
            if (!host.address.equals(source.address)) {
                hasTarget = true;
                break;
            }
        }
        if (!hasTarget) {
            logger.warn("No host available to take over the remaining streams from {}", source);
            return;
        }
        int hostIdx = hosts.size() - 1;
        while (!streamsToMove.isEmpty() && !Thread.currentThread().isInterrupted()) {
            Host target = hosts.get(hostIdx);
            if (!target.address.equals(source.address)) {
                // a permit is only needed when a stream is actually moved.
                if (rateLimiter.isPresent()) {
                    rateLimiter.get().acquire();
                }
                String stream = streamsToMove.remove();
                // move the stream
                if (moveStream(stream, source, target)) {
                    source.streams.remove(stream);
                    target.streams.add(stream);
                }
            }
            --hostIdx;
            if (hostIdx < 0) {
                hostIdx = hosts.size() - 1;
            }
        }
    }

    /**
     * Move <i>stream</i> from host <i>from</i> to host <i>to</i>, reporting failures as
     * a return value instead of an exception.
     *
     * @return true if the move completed, false if it failed or was interrupted
     */
    private boolean moveStream(String stream, Host from, Host to) {
        try {
            doMoveStream(stream, from, to);
            return true;
        } catch (InterruptedException ie) {
            // restore the interrupt status so that the move loops and the caller see it.
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while moving stream " + stream
                    + " from " + from.address + " to " + to.address, ie);
            return false;
        } catch (Exception e) {
            // best effort: a failed move is reported and the balancer carries on.
            logger.warn("Failed to move stream " + stream
                    + " from " + from.address + " to " + to.address, e);
            return false;
        }
    }

    /**
     * Release <i>stream</i> on host <i>from</i> and, once the release succeeded, issue a
     * {@code check} for it on the monitor client of host <i>to</i> (presumably this is
     * what makes the target acquire the stream -- confirm against the monitor service).
     * Blocks until both steps complete.
     *
     * @throws Exception
     *          if the release or the check fails, or if waiting for them fails
     */
    private void doMoveStream(final String stream, final Host from, final Host to) throws Exception {
        logger.info("Moving stream {} from {} to {}.",
                new Object[] { stream, from.address, to.address });
        Await.result(from.getClient().release(stream).flatMap(new Function<Void, Future<Void>>() {
            @Override
            public Future<Void> apply(Void result) {
                logger.info("Released stream {} from {}.", stream, from.address);
                return to.getMonitor().check(stream).addEventListener(new FutureEventListener<Void>() {
                    @Override
                    public void onSuccess(Void value) {
                        logger.info("Moved stream {} from {} to {}.",
                                new Object[] { stream, from.address, to.address });
                    }

                    @Override
                    public void onFailure(Throwable cause) {
                        logger.info("Failed to move stream {} from {} to {} : ",
                                new Object[] { stream, from.address, to.address, cause });
                    }
                });
            }
        }));
    }

    /**
     * Close the cluster-wide distributedlog client of this balancer.
     */
    @Override
    public void close() {
        client.close();
    }
}