blob: b2ac9ddd73554e6006f59daf1aaa3888de4bb197 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.shuffle.common;
import java.io.DataInputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import javax.crypto.SecretKey;
import javax.net.ssl.HttpsURLConnection;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.security.ssl.SSLFactory;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;
import org.apache.tez.runtime.library.shuffle.common.FetchedInput.Type;
import com.google.common.base.Preconditions;
/**
* Responsible for fetching inputs served by the ShuffleHandler for a single
* host. Construct using {@link FetcherBuilder}
*/
public class Fetcher implements Callable<FetchResult> {
private static final Log LOG = LogFactory.getLog(Fetcher.class);
private static final int UNIT_CONNECT_TIMEOUT = 60 * 1000;
private static final AtomicInteger fetcherIdGen = new AtomicInteger(0);
// Configurable fields.
private CompressionCodec codec;
private int connectionTimeout;
private int readTimeout;
private boolean ifileReadAhead = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_DEFAULT;
private int ifileReadAheadLength = TezJobConfig.TEZ_RUNTIME_IFILE_READAHEAD_BYTES_DEFAULT;
private final SecretKey shuffleSecret;
private final FetcherCallback fetcherCallback;
private final FetchedInputAllocator inputManager;
private final ApplicationId appId;
private static boolean sslShuffle = false;
private static SSLFactory sslFactory;
private static boolean sslFactoryInited;
private final int fetcherIdentifier;
// Parameters to track work.
private List<InputAttemptIdentifier> srcAttempts;
private String host;
private int port;
private int partition;
// Maps from the pathComponents (unique per srcTaskId) to the specific taskId
private final Map<String, InputAttemptIdentifier> pathToAttemptMap;
private LinkedHashSet<InputAttemptIdentifier> remaining;
private URL url;
private String encHash;
private String msgToEncode;
private Fetcher(FetcherCallback fetcherCallback,
FetchedInputAllocator inputManager, ApplicationId appId, SecretKey shuffleSecret,
Configuration conf) {
this.fetcherCallback = fetcherCallback;
this.inputManager = inputManager;
this.shuffleSecret = shuffleSecret;
this.appId = appId;
this.pathToAttemptMap = new HashMap<String, InputAttemptIdentifier>();
this.fetcherIdentifier = fetcherIdGen.getAndIncrement();
// TODO NEWTEZ Ideally, move this out from here into a static initializer block.
// Re-enable when ssl shuffle support is needed.
// synchronized (Fetcher.class) {
// if (!sslFactoryInited) {
// sslFactoryInited = true;
// sslShuffle = conf.getBoolean(
// TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
// TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL);
// if (sslShuffle) {
// sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf);
// try {
// sslFactory.init();
// } catch (Exception ex) {
// sslFactory.destroy();
// throw new RuntimeException(ex);
// }
// }
// }
// }
}
@Override
public FetchResult call() throws Exception {
if (srcAttempts.size() == 0) {
return new FetchResult(host, port, partition, srcAttempts);
}
for (InputAttemptIdentifier in : srcAttempts) {
pathToAttemptMap.put(in.getPathComponent(), in);
}
remaining = new LinkedHashSet<InputAttemptIdentifier>(srcAttempts);
HttpURLConnection connection;
try {
connection = connectToShuffleHandler(host, port, partition, srcAttempts);
} catch (IOException e) {
// ioErrs.increment(1);
// If connect did not succeed, just mark all the maps as failed,
// indirectly penalizing the host
for (Iterator<InputAttemptIdentifier> leftIter = remaining.iterator(); leftIter
.hasNext();) {
fetcherCallback.fetchFailed(host, leftIter.next(), true);
}
return new FetchResult(host, port, partition, remaining);
}
DataInputStream input;
try {
input = new DataInputStream(connection.getInputStream());
validateConnectionResponse(connection, url, msgToEncode, encHash);
} catch (IOException e) {
// ioErrs.increment(1);
// If we got a read error at this stage, it implies there was a problem
// with the first map, typically lost map. So, penalize only that map
// and add the rest
InputAttemptIdentifier firstAttempt = srcAttempts.get(0);
LOG.warn("Fetch Failure from host while connecting: " + host
+ ", attempt: " + firstAttempt + " Informing ShuffleManager: ", e);
fetcherCallback.fetchFailed(host, firstAttempt, false);
return new FetchResult(host, port, partition, remaining);
}
// By this point, the connection is setup and the response has been
// validated.
// Loop through available map-outputs and fetch them
// On any error, faildTasks is not null and we exit
// after putting back the remaining maps to the
// yet_to_be_fetched list and marking the failed tasks.
InputAttemptIdentifier[] failedInputs = null;
while (!remaining.isEmpty() && failedInputs == null) {
failedInputs = fetchInputs(input);
}
if (failedInputs != null && failedInputs.length > 0) {
LOG.warn("copyInputs failed for tasks " + Arrays.toString(failedInputs));
for (InputAttemptIdentifier left : failedInputs) {
fetcherCallback.fetchFailed(host, left, false);
}
}
IOUtils.cleanup(LOG, input);
// Sanity check
if (failedInputs == null && !remaining.isEmpty()) {
throw new IOException("server didn't return all expected map outputs: "
+ remaining.size() + " left.");
}
return new FetchResult(host, port, partition, remaining);
}
private InputAttemptIdentifier[] fetchInputs(DataInputStream input) {
FetchedInput fetchedInput = null;
InputAttemptIdentifier srcAttemptId = null;
long decompressedLength = -1;
long compressedLength = -1;
try {
long startTime = System.currentTimeMillis();
int responsePartition = -1;
// Read the shuffle header
String pathComponent = null;
try {
ShuffleHeader header = new ShuffleHeader();
header.readFields(input);
pathComponent = header.getMapId();
srcAttemptId = pathToAttemptMap.get(pathComponent);
compressedLength = header.getCompressedLength();
decompressedLength = header.getUncompressedLength();
responsePartition = header.getPartition();
} catch (IllegalArgumentException e) {
// badIdErrs.increment(1);
LOG.warn("Invalid src id ", e);
// Don't know which one was bad, so consider all of them as bad
return remaining.toArray(new InputAttemptIdentifier[remaining.size()]);
}
// Do some basic sanity verification
if (!verifySanity(compressedLength, decompressedLength,
responsePartition, srcAttemptId, pathComponent)) {
if (srcAttemptId == null) {
LOG.warn("Was expecting " + getNextRemainingAttempt() + " but got null");
srcAttemptId = getNextRemainingAttempt();
}
assert(srcAttemptId != null);
return new InputAttemptIdentifier[] { srcAttemptId };
}
if (LOG.isDebugEnabled()) {
LOG.debug("header: " + srcAttemptId + ", len: " + compressedLength
+ ", decomp len: " + decompressedLength);
}
// Get the location for the map output - either in-memory or on-disk
// TODO TEZ-957. handle IOException here when Broadcast has better error checking
fetchedInput = inputManager.allocate(decompressedLength, compressedLength, srcAttemptId);
// TODO NEWTEZ No concept of WAIT at the moment.
// // Check if we can shuffle *now* ...
// if (fetchedInput.getType() == FetchedInput.WAIT) {
// LOG.info("fetcher#" + id +
// " - MergerManager returned Status.WAIT ...");
// //Not an error but wait to process data.
// return EMPTY_ATTEMPT_ID_ARRAY;
// }
// Go!
LOG.info("fetcher" + " about to shuffle output of srcAttempt "
+ fetchedInput.getInputAttemptIdentifier() + " decomp: "
+ decompressedLength + " len: " + compressedLength + " to "
+ fetchedInput.getType());
if (fetchedInput.getType() == Type.MEMORY) {
ShuffleUtils.shuffleToMemory((MemoryFetchedInput) fetchedInput,
input, (int) decompressedLength, (int) compressedLength, codec,
ifileReadAhead, ifileReadAheadLength, LOG);
} else {
ShuffleUtils.shuffleToDisk((DiskFetchedInput) fetchedInput, input,
compressedLength, LOG);
}
// Inform the shuffle scheduler
long endTime = System.currentTimeMillis();
fetcherCallback.fetchSucceeded(host, srcAttemptId, fetchedInput,
compressedLength, decompressedLength, (endTime - startTime));
// Note successful shuffle
remaining.remove(srcAttemptId);
// metrics.successFetch();
return null;
} catch (IOException ioe) {
// ioErrs.increment(1);
if (srcAttemptId == null || fetchedInput == null) {
LOG.info("fetcher" + " failed to read map header" + srcAttemptId
+ " decomp: " + decompressedLength + ", " + compressedLength, ioe);
if (srcAttemptId == null) {
return remaining
.toArray(new InputAttemptIdentifier[remaining.size()]);
} else {
return new InputAttemptIdentifier[] { srcAttemptId };
}
}
LOG.warn("Failed to shuffle output of " + srcAttemptId + " from " + host,
ioe);
// Inform the shuffle-scheduler
try {
fetchedInput.abort();
} catch (IOException e) {
LOG.info("Failure to cleanup fetchedInput: " + fetchedInput);
}
// metrics.failedFetch();
return new InputAttemptIdentifier[] { srcAttemptId };
}
}
/**
* Do some basic verification on the input received -- Being defensive
*
* @param compressedLength
* @param decompressedLength
* @param fetchPartition
* @param remaining
* @param mapId
* @return true/false, based on if the verification succeeded or not
*/
private boolean verifySanity(long compressedLength, long decompressedLength,
int fetchPartition, InputAttemptIdentifier srcAttemptId, String pathComponent) {
if (compressedLength < 0 || decompressedLength < 0) {
// wrongLengthErrs.increment(1);
LOG.warn(" invalid lengths in input header -> headerPathComponent: "
+ pathComponent + ", nextRemainingSrcAttemptId: "
+ getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId
+ " len: " + compressedLength + ", decomp len: " + decompressedLength);
return false;
}
if (fetchPartition != this.partition) {
// wrongReduceErrs.increment(1);
LOG.warn(" data for the wrong reduce -> headerPathComponent: "
+ pathComponent + "nextRemainingSrcAttemptId: "
+ getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId
+ " len: " + compressedLength + " decomp len: " + decompressedLength
+ " for reduce " + fetchPartition);
return false;
}
// Sanity check
if (!remaining.contains(srcAttemptId)) {
// wrongMapErrs.increment(1);
LOG.warn("Invalid input. Received output for headerPathComponent: "
+ pathComponent + "nextRemainingSrcAttemptId: "
+ getNextRemainingAttempt() + ", mappedSrcAttemptId: " + srcAttemptId);
return false;
}
return true;
}
private InputAttemptIdentifier getNextRemainingAttempt() {
if (remaining.size() > 0) {
return remaining.iterator().next();
} else {
return null;
}
}
private HttpURLConnection connectToShuffleHandler(String host, int port,
int partition, List<InputAttemptIdentifier> inputs) throws IOException {
try {
this.url = constructInputURL(host, port, partition, inputs);
HttpURLConnection connection = openConnection(url);
// generate hash of the url
this.msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
this.encHash = SecureShuffleUtils.hashFromString(msgToEncode,
shuffleSecret);
// put url hash into http header
connection.addRequestProperty(SecureShuffleUtils.HTTP_HEADER_URL_HASH,
encHash);
// set the read timeout
connection.setReadTimeout(readTimeout);
// put shuffle version into http header
connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
connect(connection, connectionTimeout);
return connection;
} catch (IOException e) {
LOG.warn("Failed to connect to " + host + " with " + srcAttempts.size()
+ " inputs", e);
throw e;
}
}
private void validateConnectionResponse(HttpURLConnection connection,
URL url, String msgToEncode, String encHash) throws IOException {
int rc = connection.getResponseCode();
if (rc != HttpURLConnection.HTTP_OK) {
throw new IOException("Got invalid response code " + rc + " from " + url
+ ": " + connection.getResponseMessage());
}
if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(connection
.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
|| !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(connection
.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
throw new IOException("Incompatible shuffle response version");
}
// get the replyHash which is HMac of the encHash we sent to the server
String replyHash = connection
.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
if (replyHash == null) {
throw new IOException("security validation of TT Map output failed");
}
if (LOG.isDebugEnabled()) {
LOG.debug("url=" + msgToEncode + ";encHash=" + encHash + ";replyHash="
+ replyHash);
}
// verify that replyHash is HMac of encHash
SecureShuffleUtils.verifyReply(replyHash, encHash, shuffleSecret);
LOG.info("for url=" + msgToEncode + " sent hash and receievd reply");
}
protected HttpURLConnection openConnection(URL url) throws IOException {
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
if (sslShuffle) {
HttpsURLConnection httpsConn = (HttpsURLConnection) conn;
try {
httpsConn.setSSLSocketFactory(sslFactory.createSSLSocketFactory());
} catch (GeneralSecurityException ex) {
throw new IOException(ex);
}
httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier());
}
return conn;
}
/**
* The connection establishment is attempted multiple times and is given up
* only on the last failure. Instead of connecting with a timeout of X, we try
* connecting with a timeout of x < X but multiple times.
*/
private void connect(URLConnection connection, int connectionTimeout)
throws IOException {
int unit = 0;
if (connectionTimeout < 0) {
throw new IOException("Invalid timeout " + "[timeout = "
+ connectionTimeout + " ms]");
} else if (connectionTimeout > 0) {
unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout);
}
// set the connect timeout to the unit-connect-timeout
connection.setConnectTimeout(unit);
while (true) {
try {
connection.connect();
break;
} catch (IOException ioe) {
// update the total remaining connect-timeout
connectionTimeout -= unit;
// throw an exception if we have waited for timeout amount of time
// note that the updated value if timeout is used here
if (connectionTimeout == 0) {
throw ioe;
}
// reset the connect timeout for the last try
if (connectionTimeout < unit) {
unit = connectionTimeout;
// reset the connect time out for the final connect
connection.setConnectTimeout(unit);
}
}
}
}
private URL constructInputURL(String host, int port, int partition,
List<InputAttemptIdentifier> inputs) throws MalformedURLException {
StringBuilder url = ShuffleUtils.constructBaseURIForShuffleHandler(host,
port, partition, appId);
boolean first = true;
for (InputAttemptIdentifier input : inputs) {
if (first) {
first = false;
url.append(input.getPathComponent());
} else {
url.append(",").append(input.getPathComponent());
}
}
if (LOG.isDebugEnabled()) {
LOG.debug("InputFetch URL for: " + host + " : " + url.toString());
}
return new URL(url.toString());
}
/**
* Builder for the construction of Fetchers
*/
public static class FetcherBuilder {
private Fetcher fetcher;
private boolean workAssigned = false;
public FetcherBuilder(FetcherCallback fetcherCallback,
FetchedInputAllocator inputManager, ApplicationId appId,
SecretKey shuffleSecret, Configuration conf) {
this.fetcher = new Fetcher(fetcherCallback, inputManager, appId,
shuffleSecret, conf);
}
public FetcherBuilder setCompressionParameters(CompressionCodec codec) {
fetcher.codec = codec;
return this;
}
public FetcherBuilder setConnectionParameters(int connectionTimeout,
int readTimeout) {
fetcher.connectionTimeout = connectionTimeout;
fetcher.readTimeout = readTimeout;
return this;
}
public FetcherBuilder setIFileParams(boolean readAhead, int readAheadBytes) {
fetcher.ifileReadAhead = readAhead;
fetcher.ifileReadAheadLength = readAheadBytes;
return this;
}
public FetcherBuilder assignWork(String host, int port, int partition,
List<InputAttemptIdentifier> inputs) {
fetcher.host = host;
fetcher.port = port;
fetcher.partition = partition;
fetcher.srcAttempts = inputs;
workAssigned = true;
return this;
}
public Fetcher build() {
Preconditions.checkState(workAssigned == true,
"Cannot build a fetcher withot assigning work to it");
return fetcher;
}
}
@Override
public int hashCode() {
return fetcherIdentifier;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Fetcher other = (Fetcher) obj;
if (fetcherIdentifier != other.fetcherIdentifier)
return false;
return true;
}
}