blob: a682a09c7a668e54883eed213360cd26ce86a734 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.common.shuffle.impl;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.TezRuntimeUtils;
import com.google.common.collect.Lists;
/**
 * Schedules shuffle fetches for a single input.
 *
 * <p>It tracks which source tasks' outputs are still outstanding. It hands hosts
 * with pending outputs to the calling fetcher threads. It penalizes hosts whose
 * copies fail; a background {@link Referee} thread makes them available again
 * once the penalty expires.
 *
 * <p>Most state is guarded by this object's monitor. {@code finishedMaps} is
 * guarded by its own monitor. {@code pathToIdentifierMap} is a concurrent map,
 * so {@link #getIdentifierForPathComponent(String)} can read it without taking
 * the scheduler lock.
 */
class ShuffleScheduler {
  /**
   * Per-thread wall-clock time (ms) at which {@link #getHost()} last assigned a
   * host to the calling thread. {@link #freeHost(MapHost)} uses it for logging.
   */
  static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
    @Override
    protected Long initialValue() {
      return 0L;
    }
  };

  private static final Log LOG = LogFactory.getLog(ShuffleScheduler.class);

  /** Upper bound on the number of outputs handed out per host assignment. */
  private static final int MAX_MAPS_AT_ONCE = 20;
  /** Base penalty (ms) applied to a host after a failed copy. */
  private static final long INITIAL_PENALTY = 10000;
  /** The penalty grows as INITIAL_PENALTY * PENALTY_GROWTH_RATE^failures. */
  private static final float PENALTY_GROWTH_RATE = 1.3f;

  // TODO NEWTEZ May need to be a string if attempting to fetch from multiple inputs.
  /**
   * Number of committed copies per source task index. A task counts as finished
   * once its count equals {@code shuffle.getReduceRange()}. Guarded by its own
   * monitor.
   */
  private final Map<Integer, MutableInt> finishedMaps;
  private final int numInputs;
  /** Source tasks that are neither finished nor failed via {@link #tipFailed(int)}. */
  private int remainingMaps;
  /** One MapHost per host/partition identifier (see MapHost.createIdentifier). */
  private final Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
  //TODO NEWTEZ Clean this and other maps at some point
  private final ConcurrentMap<String, InputAttemptIdentifier> pathToIdentifierMap =
      new ConcurrentHashMap<String, InputAttemptIdentifier>();
  /** Hosts that have outputs to fetch and are ready to be handed out. */
  private final Set<MapHost> pendingHosts = new HashSet<MapHost>();
  /** Attempts whose outputs must no longer be fetched. */
  private final Set<InputAttemptIdentifier> obsoleteMaps = new HashSet<InputAttemptIdentifier>();
  private final Random random = new Random(System.currentTimeMillis());
  private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
  private final Referee referee = new Referee();
  /** Failed copy count per source attempt; cleared when that attempt copies successfully. */
  private final Map<InputAttemptIdentifier, IntWritable> failureCounts =
      new HashMap<InputAttemptIdentifier, IntWritable>();
  /** Failed copy count per host name; cleared on a successful copy from that host. */
  private final Map<String, IntWritable> hostFailures =
      new HashMap<String, IntWritable>();
  private final TezInputContext inputContext;
  private final Shuffle shuffle;
  private final int abortFailureLimit;
  private final TezCounter shuffledMapsCounter;
  private final TezCounter reduceShuffleBytes;
  private final TezCounter failedShuffleCounter;
  private final long startTime;
  private long lastProgressTime;
  private int maxMapRuntime = 0;
  private int maxFailedUniqueFetches = 5;
  private int maxFetchFailuresBeforeReporting;
  private long totalBytesShuffledTillNow = 0;
  /** Not thread-safe; only used from logProgress(), whose callers are synchronized. */
  private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");
  private boolean reportReadErrorImmediately = true;

  /**
   * Creates the scheduler and starts its penalty referee thread.
   *
   * @param inputContext context used to name the source vertex and to send read-error events
   * @param conf configuration supplying the fetch-failure reporting settings
   * @param tasksInDegree number of source tasks whose output must be fetched
   * @param shuffle owning shuffle; receives fatal exceptions and supplies the reduce range
   * @param shuffledMapsCounter incremented once per completed source task
   * @param reduceShuffleBytes incremented by the bytes of each successful copy
   * @param failedShuffleCounter incremented once per failed copy
   */
  public ShuffleScheduler(TezInputContext inputContext,
                          Configuration conf,
                          int tasksInDegree,
                          Shuffle shuffle,
                          TezCounter shuffledMapsCounter,
                          TezCounter reduceShuffleBytes,
                          TezCounter failedShuffleCounter) {
    this.inputContext = inputContext;
    this.numInputs = tasksInDegree;
    abortFailureLimit = Math.max(30, tasksInDegree / 10);
    remainingMaps = tasksInDegree;
    //TODO NEWTEZ May need to be a string or a more usable construct if attempting to fetch from multiple inputs. Define a taskId / taskAttemptId pair
    finishedMaps = new HashMap<Integer, MutableInt>(remainingMaps);
    this.shuffle = shuffle;
    this.shuffledMapsCounter = shuffledMapsCounter;
    this.reduceShuffleBytes = reduceShuffleBytes;
    this.failedShuffleCounter = failedShuffleCounter;
    this.startTime = System.currentTimeMillis();
    this.lastProgressTime = startTime;
    this.maxFailedUniqueFetches = Math.min(tasksInDegree,
        this.maxFailedUniqueFetches);
    this.maxFetchFailuresBeforeReporting =
        conf.getInt(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
    this.reportReadErrorImmediately =
        conf.getBoolean(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);
    // Start the referee last so it never observes a partially initialized scheduler.
    referee.start();
  }

  /**
   * Records a successful copy from {@code host}.
   *
   * <p>Clears the failure history of the attempt and the host. If the source
   * task is not yet finished, it commits {@code output}, counts the copy and
   * updates the progress statistics. It wakes waiters when the last source task
   * completes.
   *
   * @param srcAttemptIdentifier the attempt whose output was copied
   * @param host host the output was fetched from
   * @param bytes number of bytes copied
   * @param milis copy duration (currently unused)
   * @param output fetched output to commit
   * @throws IOException if committing the output fails
   */
  public synchronized void copySucceeded(InputAttemptIdentifier srcAttemptIdentifier,
                                         MapHost host,
                                         long bytes,
                                         long milis,
                                         MapOutput output
                                         ) throws IOException {
    // failureCounts is keyed by InputAttemptIdentifier, so remove by the attempt
    // itself; a formatted String key would never match.
    failureCounts.remove(srcAttemptIdentifier);
    hostFailures.remove(host.getHostName());

    int srcTaskIndex = srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex();
    if (!isFinishedTaskTrue(srcTaskIndex)) {
      output.commit();
      if (incrementTaskCopyAndCheckCompletion(srcTaskIndex)) {
        shuffledMapsCounter.increment(1);
        if (--remainingMaps == 0) {
          notifyAll();
        }
      }

      // update the status
      lastProgressTime = System.currentTimeMillis();
      totalBytesShuffledTillNow += bytes;
      logProgress();
      reduceShuffleBytes.increment(bytes);
      if (LOG.isDebugEnabled()) {
        LOG.debug("src task: "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcTaskIndex,
                srcAttemptIdentifier.getAttemptNumber()) + " done");
      }
    }
    // TODO NEWTEZ Should this be releasing the output, if not committed ? Possible memory leak in case of speculation.
  }

  /** Logs completed inputs and the average transfer rate since start. Caller must hold the lock. */
  private void logProgress() {
    float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
    int mapsDone = numInputs - remainingMaps;
    // +1 avoids dividing by zero during the first second.
    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;

    float transferRate = mbs / secsSinceStart;
    LOG.info("copy(" + mapsDone + " of " + numInputs + " at "
        + mbpsFormat.format(transferRate) + " MB/s)");
  }

  /**
   * Records a failed copy of {@code srcAttempt} from {@code host}.
   *
   * <p>The host is penalized with an exponentially growing delay. Read errors
   * may be reported upstream. When the attempt has failed
   * {@code abortFailureLimit} times, or the reducer looks unhealthy, a fatal
   * exception is reported to the shuffle.
   *
   * @param srcAttempt attempt whose output could not be copied
   * @param host host the copy was attempted from
   * @param readError whether the failure was a read error
   */
  public synchronized void copyFailed(InputAttemptIdentifier srcAttempt,
                                      MapHost host,
                                      boolean readError) {
    host.penalize();
    int failures = incrementCount(failureCounts, srcAttempt);
    incrementCount(hostFailures, host.getHostName());

    if (failures >= abortFailureLimit) {
      shuffle.reportException(new IOException(failures
          + " failures downloading "
          + TezRuntimeUtils.getTaskAttemptIdentifier(
              inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
              srcAttempt.getAttemptNumber())));
    }

    checkAndInformJobTracker(failures, srcAttempt, readError);

    checkReducerHealth();

    long delay = (long) (INITIAL_PENALTY *
        Math.pow(PENALTY_GROWTH_RATE, failures));

    penalties.add(new Penalty(host, delay));

    failedShuffleCounter.increment(1);
  }

  /** Increments the counter stored under {@code key} (creating it at 1) and returns the new value. */
  private static <K> int incrementCount(Map<K, IntWritable> counts, K key) {
    IntWritable count = counts.get(key);
    if (count == null) {
      counts.put(key, new IntWritable(1));
      return 1;
    }
    count.set(count.get() + 1);
    return count.get();
  }

  /**
   * Sends an {@link InputReadErrorEvent} for {@code srcAttempt} in two cases:
   * after a read error when {@code reportReadErrorImmediately} is set, or after
   * every {@code maxFetchFailuresBeforeReporting} failures. Periodic reporting
   * is skipped when that limit is not positive, which avoids a modulo by zero.
   */
  private void checkAndInformJobTracker(
      int failures, InputAttemptIdentifier srcAttempt, boolean readError) {
    boolean periodicReportDue = maxFetchFailuresBeforeReporting > 0
        && (failures % maxFetchFailuresBeforeReporting) == 0;
    if ((reportReadErrorImmediately && readError) || periodicReportDue) {
      int srcTaskIndex = srcAttempt.getInputIdentifier().getSrcTaskIndex();
      String attemptId = TezRuntimeUtils.getTaskAttemptIdentifier(
          inputContext.getSourceVertexName(), srcTaskIndex, srcAttempt.getAttemptNumber());
      LOG.info("Reporting fetch failure for " + attemptId + " to jobtracker.");

      List<Event> failedEvents = Lists.newArrayListWithCapacity(1);
      failedEvents.add(new InputReadErrorEvent("Fetch failure for " + attemptId + " to jobtracker.",
          srcTaskIndex, srcAttempt.getAttemptNumber()));

      inputContext.sendEvents(failedEvents);
    }
  }

  /**
   * Reports a fatal exception to the shuffle when all of these hold:
   * <ul>
   *   <li>enough unique attempts are failing, or every remaining input is failing;</li>
   *   <li>failed fetches are at least half of (failures + completed inputs);</li>
   *   <li>either under half of the inputs are done, or the shuffle has stalled for
   *       at least half of max(time spent making progress, longest reported map runtime).</li>
   * </ul>
   * Note: copyFailed increments failedShuffleCounter only after this call, so the
   * current failure is not yet included in the failure ratio.
   */
  private void checkReducerHealth() {
    final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
    final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
    final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;

    long totalFailures = failedShuffleCounter.getValue();
    int doneMaps = numInputs - remainingMaps;

    boolean reducerHealthy =
        (((float) totalFailures / (totalFailures + doneMaps))
            < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);

    // check if the reducer has progressed enough
    boolean reducerProgressedEnough =
        (((float) doneMaps / numInputs)
            >= MIN_REQUIRED_PROGRESS_PERCENT);

    // check if the reducer is stalled for a long time
    // duration for which the reducer is stalled
    int stallDuration =
        (int) (System.currentTimeMillis() - lastProgressTime);

    // duration for which the reducer ran with progress
    int shuffleProgressDuration =
        (int) (lastProgressTime - startTime);

    // min time the reducer should run without getting killed
    int minShuffleRunDuration =
        (shuffleProgressDuration > maxMapRuntime)
            ? shuffleProgressDuration
            : maxMapRuntime;

    boolean reducerStalled =
        (((float) stallDuration / minShuffleRunDuration)
            >= MAX_ALLOWED_STALL_TIME_PERCENT);

    // kill if not healthy and has insufficient progress
    if ((failureCounts.size() >= maxFailedUniqueFetches ||
        failureCounts.size() == (numInputs - doneMaps))
        && !reducerHealthy
        && (!reducerProgressedEnough || reducerStalled)) {
      LOG.fatal("Shuffle failed with too many fetch failures " +
          "and insufficient progress!");
      String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
      shuffle.reportException(new IOException(errorMsg));
    }
  }

  /**
   * Marks source task {@code srcTaskIndex} as finished without copying its output.
   * Waiters are woken when no source tasks remain.
   */
  public synchronized void tipFailed(int srcTaskIndex) {
    if (!isFinishedTaskTrue(srcTaskIndex)) {
      setFinishedTaskTrue(srcTaskIndex);
      if (--remainingMaps == 0) {
        notifyAll();
      }
      logProgress();
    }
  }

  /**
   * Registers an output of {@code srcAttempt} available on {@code hostName} for
   * {@code partitionId}.
   *
   * <p>The MapHost for that host/partition pair is created if needed and reused
   * otherwise. If the host becomes PENDING, it is queued and waiters are woken.
   */
  public synchronized void addKnownMapOutput(String hostName,
                                             int partitionId,
                                             String hostUrl,
                                             InputAttemptIdentifier srcAttempt) {
    String identifier = MapHost.createIdentifier(hostName, partitionId);
    // mapLocations is keyed by this identifier so that all outputs from the same
    // host/partition accumulate on a single MapHost.
    MapHost host = mapLocations.get(identifier);
    if (host == null) {
      host = new MapHost(partitionId, hostName, hostUrl);
      assert identifier.equals(host.getIdentifier());
      mapLocations.put(identifier, host);
    }
    host.addKnownMap(srcAttempt);
    pathToIdentifierMap.put(srcAttempt.getPathComponent(), srcAttempt);

    // Mark the host as pending
    if (host.getState() == MapHost.State.PENDING) {
      pendingHosts.add(host);
      notifyAll();
    }
  }

  /** Marks {@code srcAttempt}'s output as obsolete so it is skipped by {@link #getMapsForHost}. */
  public synchronized void obsoleteMapOutput(InputAttemptIdentifier srcAttempt) {
    // The incoming srcAttempt does not contain a path component.
    obsoleteMaps.add(srcAttempt);
  }

  /** Returns {@code srcAttempt} to {@code host}'s list of outputs still to fetch. */
  public synchronized void putBackKnownMapOutput(MapHost host,
                                                 InputAttemptIdentifier srcAttempt) {
    host.addKnownMap(srcAttempt);
  }

  /**
   * Blocks until a pending host exists, then removes a randomly chosen one,
   * marks it busy and returns it. Also records the assignment time for the
   * calling thread.
   *
   * @throws InterruptedException if interrupted while waiting
   */
  public synchronized MapHost getHost() throws InterruptedException {
    while (pendingHosts.isEmpty()) {
      wait();
    }

    MapHost host = null;
    Iterator<MapHost> iter = pendingHosts.iterator();
    int numToPick = random.nextInt(pendingHosts.size());
    for (int i = 0; i <= numToPick; ++i) {
      host = iter.next();
    }

    pendingHosts.remove(host);
    host.markBusy();

    LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
        " to " + Thread.currentThread().getName());
    shuffleStart.set(System.currentTimeMillis());

    return host;
  }

  /** Returns the attempt registered for {@code pathComponent}, or null if none. Lock-free. */
  public InputAttemptIdentifier getIdentifierForPathComponent(String pathComponent) {
    return pathToIdentifierMap.get(pathComponent);
  }

  /**
   * Takes up to {@link #MAX_MAPS_AT_ONCE} still-needed outputs from {@code host}.
   *
   * <p>An output is still needed if it is neither obsolete nor from a finished
   * source task. Still-needed outputs beyond the limit are put back on the host.
   * Outputs that are no longer needed are dropped.
   */
  public synchronized List<InputAttemptIdentifier> getMapsForHost(MapHost host) {
    List<InputAttemptIdentifier> list = host.getAndClearKnownMaps();
    Iterator<InputAttemptIdentifier> itr = list.iterator();
    List<InputAttemptIdentifier> result = new ArrayList<InputAttemptIdentifier>();
    int includedMaps = 0;
    int totalSize = list.size();
    // find the maps that we still need, up to the limit
    while (itr.hasNext()) {
      InputAttemptIdentifier id = itr.next();
      if (isStillNeeded(id)) {
        result.add(id);
        if (++includedMaps >= MAX_MAPS_AT_ONCE) {
          break;
        }
      }
    }
    // put back the maps left after the limit
    while (itr.hasNext()) {
      InputAttemptIdentifier id = itr.next();
      if (isStillNeeded(id)) {
        host.addKnownMap(id);
      }
    }
    LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
        host + " to " + Thread.currentThread().getName());
    return result;
  }

  /** True if {@code id} is not obsolete and its source task has not finished. */
  private boolean isStillNeeded(InputAttemptIdentifier id) {
    return !obsoleteMaps.contains(id)
        && !isFinishedTaskTrue(id.getInputIdentifier().getSrcTaskIndex());
  }

  /**
   * Releases {@code host} after a fetch. Unless the host is penalized, it is
   * re-queued (and waiters woken) if it becomes PENDING.
   */
  public synchronized void freeHost(MapHost host) {
    if (host.getState() != MapHost.State.PENALIZED) {
      if (host.markAvailable() == MapHost.State.PENDING) {
        pendingHosts.add(host);
        notifyAll();
      }
    }
    // The elapsed time is in milliseconds.
    LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
        (System.currentTimeMillis() - shuffleStart.get()) + "ms");
  }

  /** Forgets all known hosts, obsolete outputs, pending hosts and path mappings. */
  public synchronized void resetKnownMaps() {
    mapLocations.clear();
    obsoleteMaps.clear();
    pendingHosts.clear();
    pathToIdentifierMap.clear();
  }

  /**
   * Utility method to check if the Shuffle data fetch is complete.
   * @return true once every source task has been copied or marked failed
   */
  public synchronized boolean isDone() {
    return remainingMaps == 0;
  }

  /**
   * Wait until the shuffle finishes or until the timeout.
   * A single wait is performed, so this may return false early on a spurious wakeup;
   * callers should re-check.
   * @param millis maximum wait time
   * @return true if the shuffle is done
   * @throws InterruptedException
   */
  public synchronized boolean waitUntilDone(int millis
                                            ) throws InterruptedException {
    if (remainingMaps > 0) {
      wait(millis);
      return remainingMaps == 0;
    }
    return true;
  }

  /**
   * A structure that records the penalty for a host.
   */
  private static class Penalty implements Delayed {
    final MapHost host;
    private final long endTime;

    Penalty(MapHost host, long delay) {
      this.host = host;
      this.endTime = System.currentTimeMillis() + delay;
    }

    @Override
    public long getDelay(TimeUnit unit) {
      long remainingTime = endTime - System.currentTimeMillis();
      return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
    }

    @Override
    public int compareTo(Delayed o) {
      long other = ((Penalty) o).endTime;
      return endTime == other ? 0 : (endTime < other ? -1 : 1);
    }

  }

  /**
   * A daemon thread that takes hosts off the penalty list when their timers
   * expire. Such hosts are re-queued if they become PENDING.
   */
  private class Referee extends Thread {
    public Referee() {
      setName("ShufflePenaltyReferee");
      setDaemon(true);
    }

    @Override
    public void run() {
      try {
        while (true) {
          // take the first host that has an expired penalty
          MapHost host = penalties.take().host;
          synchronized (ShuffleScheduler.this) {
            if (host.markAvailable() == MapHost.State.PENDING) {
              pendingHosts.add(host);
              ShuffleScheduler.this.notifyAll();
            }
          }
        }
      } catch (InterruptedException ie) {
        // Interruption is the shutdown signal sent by close().
        return;
      } catch (Throwable t) {
        shuffle.reportException(t);
      }
    }
  }

  /** Stops the referee thread and waits for it to exit. */
  public void close() throws InterruptedException {
    referee.interrupt();
    referee.join();
  }

  /** Records {@code duration} if it exceeds the longest map runtime seen so far. */
  public synchronized void informMaxMapRunTime(int duration) {
    if (duration > maxMapRuntime) {
      maxMapRuntime = duration;
    }
  }

  /** Marks a source task as finished by setting its copy count to the reduce range. */
  void setFinishedTaskTrue(int srcTaskIndex) {
    synchronized (finishedMaps) {
      finishedMaps.put(srcTaskIndex, new MutableInt(shuffle.getReduceRange()));
    }
  }

  /** Counts one more copy for {@code srcTaskIndex} and returns whether the task is now finished. */
  boolean incrementTaskCopyAndCheckCompletion(int srcTaskIndex) {
    synchronized (finishedMaps) {
      MutableInt result = finishedMaps.get(srcTaskIndex);
      if (result == null) {
        result = new MutableInt(0);
        finishedMaps.put(srcTaskIndex, result);
      }
      result.increment();

      return isFinishedTaskTrue(srcTaskIndex);
    }
  }

  /** True if the copy count for {@code srcTaskIndex} equals {@code shuffle.getReduceRange()}. */
  boolean isFinishedTaskTrue(int srcTaskIndex) {
    synchronized (finishedMaps) {
      MutableInt result = finishedMaps.get(srcTaskIndex);
      return result != null && result.intValue() == shuffle.getReduceRange();
    }
  }
}