/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.uniffle.shuffle.manager;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;

import static org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED;
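
/**
* Base class for Uniffle's Spark {@link ShuffleManager} implementations. It provides the
* encoding of shuffle-stage-unique task attempt ids for block ids, and the driver-side logic
* to clear and re-register map output state when a stage is resubmitted.
*/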
public abstract class RssShuffleManagerBase implements RssShuffleManagerInterface, ShuffleManager {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);
private final AtomicBoolean isInitialized = new AtomicBoolean(false);
// reflective MapOutputTrackerMaster methods, resolved once based on the running Spark version
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;
/** See the static overload of this method. */
public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
/**
* Provides a task attempt id, unique within a shuffle stage, to be used in the block id.
*
* <p>We do not use context.taskAttemptId() here because that value is a monotonically
* increasing number, unique across the entire Spark app, which can grow very large and
* practically reach Long.MAX_VALUE. Such a value would overflow the bits available for it in
* the block id.
*
* <p>Here we use the map index (task id) combined with the per-task attempt number. The map
* index is bounded by the number of partitions of a stage, and the attempt number per task is
* bounded and configured by spark.task.maxFailures (default: 4).
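*
* <p>For example, with spark.task.maxFailures = 4 and speculation disabled, the maximum
* attempt number is 3, which occupies 2 bits. A map index of 5 occupies 3 bits, so the method
* returns {@code 5 << 2 | attemptNo}, e.g. {@code 0b10101} = 21 for attempt number 1.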
*
* @return a task attempt id unique within a shuffle stage
*/
protected static long getTaskAttemptIdForBlockId(
int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
// attempt number is zero based: 0, 1, …, maxFailures-1
// maxFailures < 1 is not allowed, but for safety we interpret that as maxFailures == 1
int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
// with speculative execution enabled we may observe one extra attempt
if (speculation) {
maxAttemptNo++;
}
if (attemptNo > maxAttemptNo) {
// this should never happen; if it does, our assumptions are wrong,
// and we risk overflowing the attempt number bits
throw new RssException(
"Observing attempt number "
+ attemptNo
+ " while maxFailures is set to "
+ maxFailures
+ (speculation ? " with speculation enabled" : "")
+ ".");
}
int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
throw new RssException(
"Observing mapIndex["
+ mapIndex
+ "] that would produce a taskAttemptId with "
+ (mapIndexBits + attemptBits)
+ " bits which is larger than the allowed "
+ maxTaskAttemptIdBits
+ " bits (maxFailures["
+ maxFailures
+ "], speculation["
+ speculation
+ "]). Please consider providing more bits for taskAttemptIds.");
}
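// pack the map index into the high bits and the attempt number into the low bits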
return (long) mapIndex << attemptBits | attemptNo;
}
@Override
public void unregisterAllMapOutput(int shuffleId) throws SparkException {
if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
return;
}
MapOutputTrackerMaster tracker = getMapOutputTrackerMaster();
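// resolve the reflective tracker methods only once; they are fixed for the lifetime of the app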
if (isInitialized.compareAndSet(false, true)) {
unregisterAllMapOutputMethod = getUnregisterAllMapOutputMethod(tracker);
registerShuffleMethod = getRegisterShuffleMethod(tracker);
}
if (unregisterAllMapOutputMethod != null) {
try {
unregisterAllMapOutputMethod.invoke(tracker, shuffleId);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new RssException("Invoke unregisterAllMapOutput method failed", e);
}
} else {
int numMaps = getNumMaps(shuffleId);
int numReduces = getPartitionNum(shuffleId);
defaultUnregisterAllMapOutput(tracker, registerShuffleMethod, shuffleId, numMaps, numReduces);
}
}
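
/**
* Fallback for Spark versions without unregisterAllMapOutput: unregisters the whole shuffle,
* re-registers it reflectively, and bumps the tracker epoch so stale map statuses are dropped.
*/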
private static void defaultUnregisterAllMapOutput(
MapOutputTrackerMaster tracker,
Method registerShuffle,
int shuffleId,
int numMaps,
int numReduces)
throws SparkException {
if (tracker != null && registerShuffle != null) {
tracker.unregisterShuffle(shuffleId);
// re-register this shuffle id with the map output tracker
try {
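// Spark >= 3.2 added a numReduceTasks parameter to registerShuffle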
if (SparkVersionUtils.MAJOR_VERSION > 3
|| (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2)) {
registerShuffle.invoke(tracker, shuffleId, numMaps, numReduces);
} else {
registerShuffle.invoke(tracker, shuffleId, numMaps);
}
} catch (InvocationTargetException | IllegalAccessException e) {
throw new RssException("Invoke registerShuffle method failed", e);
}
tracker.incrementEpoch();
} else {
throw new SparkException(
"default unregisterAllMapOutput should only be called on the driver side");
}
}
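
/**
* Reflectively resolves the MapOutputTrackerMaster method that clears all map output of a
* shuffle: unregisterAllMapOutput for Spark 2.4 through 3.1, unregisterAllMapAndMergeOutput
* for Spark 3.2+. Returns null when no such method exists, in which case the caller falls
* back to defaultUnregisterAllMapOutput.
*/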
private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
if (tracker != null) {
Class<? extends MapOutputTrackerMaster> klass = tracker.getClass();
Method m = null;
try {
if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
// Spark 2.3 and earlier have no unregisterAllMapOutput support
LOG.warn("Spark version <= 2.3, falling back to the default method");
} else if (SparkVersionUtils.isSpark2()) {
// unregisterAllMapOutput was added in Spark 2.4
m = klass.getDeclaredMethod("unregisterAllMapOutput", int.class);
} else if (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1) {
// Spark 3.0 and 3.1 provide the unregisterAllMapOutput method
m = klass.getDeclaredMethod("unregisterAllMapOutput", int.class);
} else if (SparkVersionUtils.isSpark3()) {
// Spark 3.2+ renamed the method to unregisterAllMapAndMergeOutput
m = klass.getDeclaredMethod("unregisterAllMapAndMergeOutput", int.class);
} else {
LOG.warn(
"Unknown Spark version({}), falling back to the default method",
SparkVersionUtils.SPARK_VERSION);
}
} catch (NoSuchMethodException e) {
LOG.warn(
"Failed to find the unregisterAllMapOutput method for Spark version({})",
SparkVersionUtils.SPARK_VERSION,
e);
}
return m;
} else {
return null;
}
}
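
/**
* Reflectively resolves MapOutputTrackerMaster#registerShuffle, whose signature gained a
* numReduceTasks parameter in Spark 3.2. Returns null if the method cannot be found.
*/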
private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
if (tracker != null) {
Class<? extends MapOutputTrackerMaster> klass = tracker.getClass();
Method m = null;
try {
if (SparkVersionUtils.MAJOR_VERSION > 3
|| (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2)) {
// for Spark >= 3.2, the registerShuffle method signature changed to:
// registerShuffle(shuffleId, numMapTasks, numReduceTasks)
m = klass.getDeclaredMethod("registerShuffle", int.class, int.class, int.class);
} else {
m = klass.getDeclaredMethod("registerShuffle", int.class, int.class);
}
} catch (NoSuchMethodException e) {
LOG.warn(
"Failed to find the registerShuffle method for Spark version({})",
SparkVersionUtils.SPARK_VERSION,
e);
}
return m;
} else {
return null;
}
}
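
/** Returns the driver-side MapOutputTrackerMaster, or null on executors or before SparkEnv exists. */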
private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
MapOutputTracker tracker =
Optional.ofNullable(SparkEnv.get()).map(SparkEnv::mapOutputTracker).orElse(null);
return tracker instanceof MapOutputTrackerMaster ? (MapOutputTrackerMaster) tracker : null;
}
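
/** Copies all entries of the given Hadoop configuration into a plain map. */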
private static Map<String, String> parseRemoteStorageConf(Configuration conf) {
Map<String, String> confItems = Maps.newHashMap();
for (Map.Entry<String, String> entry : conf) {
confItems.put(entry.getKey(), entry.getValue());
}
return confItems;
}
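
/**
* Builds the default RemoteStorageInfo from the Spark conf: optionally seeds it from the
* local Hadoop configuration, then overlays every rss conf key carrying the Hadoop config
* prefix, with the prefix stripped. For instance, assuming the prefix is "rss.hadoop.", a key
* rss.hadoop.fs.defaultFS would be forwarded as fs.defaultFS.
*/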
protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkConf) {
Map<String, String> confItems = Maps.newHashMap();
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
if (rssConf.getBoolean(RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED)) {
confItems = parseRemoteStorageConf(new Configuration(true));
}
for (String key : rssConf.getKeySet()) {
if (key.startsWith(HADOOP_CONFIG_KEY_PREFIX)) {
String val = rssConf.getString(key, null);
if (val != null) {
String extractedKey = key.replaceFirst(HADOOP_CONFIG_KEY_PREFIX, "");
confItems.put(extractedKey, val);
}
}
}
return new RemoteStorageInfo(
sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems);
}
}