[#819 ] feat(tez): Tez ApplicationMaster supporting RemoteShuffle (#918)
### What changes were proposed in this pull request?
1. RssDAGAppMaster starts an RPC server and handles getShuffleAssignments requests (assigning Uniffle shuffle workers) from map/reduce tasks for shuffle write/read
2. Keep the heartbeat of those Uniffle workers with the coordinator alive throughout the application lifecycle
### Why are the changes needed?
Fix: #819
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT
Co-authored-by: qiujiyang <knight.yang@huolala.cn>
diff --git a/client-tez/pom.xml b/client-tez/pom.xml
index 4188c1f..efa3141 100644
--- a/client-tez/pom.xml
+++ b/client-tez/pom.xml
@@ -41,6 +41,12 @@
<dependency>
<groupId>org.apache.tez</groupId>
<artifactId>tez-runtime-library</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-auth</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.apache.tez</groupId>
diff --git a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
new file mode 100644
index 0000000..65da375
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.GnuParser;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.mapreduce.JobSubmissionFiles;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.ShutdownHookManager;
+import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.util.Clock;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.hadoop.yarn.util.SystemClock;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.TezClassLoader;
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.common.VersionInfo;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.api.records.DAGProtos.AMPluginDescriptorProto;
+import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+import static org.apache.tez.common.TezCommonUtils.TEZ_SYSTEM_SUB_DIR;
+
+public class RssDAGAppMaster extends DAGAppMaster {
+ private static final Logger LOG = LoggerFactory.getLogger(RssDAGAppMaster.class);
+ private ShuffleWriteClient shuffleWriteClient;
+ private TezRemoteShuffleManager tezRemoteShuffleManager;
+ private static final String rssConfFileLocalResourceName = "rss_conf.xml";
+
+ private DAGProtos.PlanLocalResource rssConfFileLocalResource;
+ final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
+ ThreadUtils.getThreadFactory("AppHeartbeat-%d")
+ );
+
+ /**
+  * Creates the RSS-aware application master.
+  *
+  * <p>All arguments are forwarded unchanged to the {@link DAGAppMaster} constructor; no RSS state is
+  * created here.
+  */
+ public RssDAGAppMaster(
+     ApplicationAttemptId applicationAttemptId,
+     ContainerId containerId,
+     String nmHost,
+     int nmPort,
+     int nmHttpPort,
+     Clock clock,
+     long appSubmitTime,
+     boolean isSession,
+     String workingDirectory,
+     String[] localDirs,
+     String[] logDirs,
+     String clientVersion,
+     Credentials credentials,
+     String jobUserName,
+     AMPluginDescriptorProto pluginDescriptorProto) {
+   super(
+       applicationAttemptId, containerId, nmHost, nmPort, nmHttpPort, clock, appSubmitTime, isSession,
+       workingDirectory, localDirs, logDirs, clientVersion, credentials, jobUserName, pluginDescriptorProto);
+ }
+
+ /** Returns the shuffle write client currently held by this application master (may be null). */
+ public ShuffleWriteClient getShuffleWriteClient() {
+   return this.shuffleWriteClient;
+ }
+
+ /** Replaces the shuffle write client held by this application master. */
+ public void setShuffleWriteClient(ShuffleWriteClient shuffleWriteClient) {
+   this.shuffleWriteClient = shuffleWriteClient;
+ }
+
+ /** Returns the remote shuffle manager currently held by this application master (may be null). */
+ public TezRemoteShuffleManager getTezRemoteShuffleManager() {
+   return this.tezRemoteShuffleManager;
+ }
+
+ /** Replaces the remote shuffle manager held by this application master. */
+ public void setTezRemoteShuffleManager(TezRemoteShuffleManager tezRemoteShuffleManager) {
+   this.tezRemoteShuffleManager = tezRemoteShuffleManager;
+ }
+
+ /**
+ * Initializes and starts the RSS (remote shuffle service) client side of the application master.
+ *
+ * <p>In order, this creates the shuffle write client and stores it on {@code appMaster}, registers the
+ * coordinators and the application, schedules a periodic application heartbeat, starts the
+ * {@link TezRemoteShuffleManager}, writes the manager's host name and port into an extra
+ * configuration file via {@code writeExtraConf}, and finally calls {@code mayCloseTezSlowStart}.
+ *
+ * @param appMaster the application master that receives the client, the shuffle manager and the heartbeat task
+ * @param conf configuration read for the coordinator quorum and the heartbeat settings
+ * @param applicationAttemptId attempt id; its string form is used as the RSS application id
+ * @throws Exception if any of the steps above fails
+ */
+ public static void initAndStartRSSClient(final RssDAGAppMaster appMaster, Configuration conf,
+ ApplicationAttemptId applicationAttemptId) throws Exception {
+
+ ShuffleWriteClient client = RssTezUtils.createShuffleClient(conf);
+ appMaster.setShuffleWriteClient(client);
+
+ String coordinators = conf.get(RssTezConfig.RSS_COORDINATOR_QUORUM);
+ LOG.info("Registering coordinators {}", coordinators);
+ client.registerCoordinators(coordinators);
+
+ String strAppAttemptId = applicationAttemptId.toString();
+ long heartbeatInterval = conf.getLong(RssTezConfig.RSS_HEARTBEAT_INTERVAL,
+ RssTezConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
+ long heartbeatTimeout = conf.getLong(RssTezConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2); // default: half the interval
+ client.registerApplicationInfo(strAppAttemptId, heartbeatTimeout, "user"); // NOTE(review): user is the literal "user" - confirm intended
+
+ appMaster.scheduledExecutorService.scheduleAtFixedRate(
+ () -> {
+ try {
+ client.sendAppHeartbeat(strAppAttemptId, heartbeatTimeout);
+ LOG.debug("Finish send heartbeat to coordinator and servers");
+ } catch (Exception e) {
+ LOG.warn("Fail to send heartbeat to coordinator and servers", e); // a failed heartbeat is only logged
+ }
+ },
+ heartbeatInterval / 2, // initial delay
+ heartbeatInterval, // period
+ TimeUnit.MILLISECONDS);
+
+ appMaster.setTezRemoteShuffleManager(new TezRemoteShuffleManager(strAppAttemptId, null, conf, // session token is null
+ strAppAttemptId, client));
+ appMaster.getTezRemoteShuffleManager().initialize();
+ appMaster.getTezRemoteShuffleManager().start();
+
+ TezConfiguration extraConf = new TezConfiguration(false);
+ extraConf.clear();
+
+ String strAppId = applicationAttemptId.getApplicationId().toString();
+ extraConf.set(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS,
+ appMaster.getTezRemoteShuffleManager().address.getHostName());
+ extraConf.setInt(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT,
+ appMaster.getTezRemoteShuffleManager().address.getPort());
+ writeExtraConf(appMaster, conf, extraConf, strAppId);
+
+ mayCloseTezSlowStart(conf);
+ }
+
+ /**
+  * Adds the RSS configuration file to the local resources of {@code dagPlan} and then delegates the
+  * submission to {@link DAGAppMaster}.
+  *
+  * @throws TezException if the resource cannot be added or the delegated submission fails
+  */
+ @Override
+ public String submitDAGToAppMaster(DAGProtos.DAGPlan dagPlan, Map<String, LocalResource> additionalResources)
+     throws TezException {
+   DAGProtos.PlanLocalResource rssConfResource = getRssConfFileLocalResource();
+   addAdditionalResource(dagPlan, rssConfResource);
+   return super.submitDAGToAppMaster(dagPlan, additionalResources);
+ }
+
+ /** Returns the local resource describing the uploaded RSS configuration file (null until it is written). */
+ public DAGProtos.PlanLocalResource getRssConfFileLocalResource() {
+   return this.rssConfFileLocalResource;
+ }
+
+ /**
+  * Appends {@code additionalResource} to the local resources of {@code dagPlan}.
+  *
+  * <p>A copy of the current resource list plus the new entry is written into the plan's private
+  * {@code localResource_} field through reflection.
+  *
+  * @param dagPlan the plan whose local resource list is replaced
+  * @param additionalResource the resource appended to the list
+  * @throws TezException if the reflective update fails; the original failure is kept as the cause
+  */
+ public static void addAdditionalResource(DAGProtos.DAGPlan dagPlan, DAGProtos.PlanLocalResource additionalResource)
+     throws TezException {
+   List<DAGProtos.PlanLocalResource> planLocalResourceList = dagPlan.getLocalResourceList();
+
+   if (planLocalResourceList == null) {
+     LOG.warn("planLocalResourceList is null, add new list");
+     planLocalResourceList = new ArrayList<>();
+   } else {
+     // Work on a copy: the list returned by the plan is not modified in place.
+     planLocalResourceList = new ArrayList<>(planLocalResourceList);
+   }
+   planLocalResourceList.add(additionalResource);
+
+   try {
+     Field field = DAGProtos.DAGPlan.class.getDeclaredField("localResource_");
+     field.setAccessible(true);
+     try {
+       field.set(dagPlan, planLocalResourceList);
+     } finally {
+       // Restore the access flag even when the write fails.
+       field.setAccessible(false);
+     }
+   } catch (Exception e) {
+     LOG.error("submitDAGToAppMaster reflect error", e);
+     // Keep the cause so callers can see the underlying reflection failure.
+     throw new TezException(e.getMessage(), e);
+   }
+
+   if (LOG.isDebugEnabled()) {
+     for (DAGProtos.PlanLocalResource localResource : dagPlan.getLocalResourceList()) {
+       LOG.debug("localResource: {}", localResource);
+     }
+   }
+ }
+
+ /**
+ * Entry point of the RSS-enabled Tez application master.
+ *
+ * <p>Reads the container id, submit time, client version, user and node manager settings from the
+ * environment, loads the user-specified Tez configuration from the working directory, creates the
+ * {@link RssDAGAppMaster}, registers the shutdown hooks, then starts the RSS client and the
+ * application master. Any {@link Throwable} is logged and the JVM exits with status 1.
+ *
+ * @param args command line arguments; only the Tez session-mode option is parsed
+ */
+ public static void main(String[] args) {
+ try {
+ // Install the tez class loader, which can be used add new resources
+ TezClassLoader.setupTezClassLoader();
+ Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler());
+ final String pid = System.getenv().get("JVM_PID");
+ String containerIdStr =
+ System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name());
+ String appSubmitTimeStr =
+ System.getenv(ApplicationConstants.APP_SUBMIT_TIME_ENV);
+ String clientVersion = System.getenv(TezConstants.TEZ_CLIENT_VERSION_ENV);
+ if (clientVersion == null) {
+ clientVersion = VersionInfo.UNKNOWN;
+ }
+
+ Objects.requireNonNull(appSubmitTimeStr,
+ ApplicationConstants.APP_SUBMIT_TIME_ENV + " is null");
+
+ ContainerId containerId = ConverterUtils.toContainerId(containerIdStr);
+ ApplicationAttemptId applicationAttemptId = containerId.getApplicationAttemptId();
+
+ String jobUserName = System
+ .getenv(ApplicationConstants.Environment.USER.name());
+
+ // Command line options
+ Options opts = new Options();
+ opts.addOption(TezConstants.TEZ_SESSION_MODE_CLI_OPTION,
+ false, "Run Tez Application Master in Session mode");
+
+ CommandLine cliParser = new GnuParser().parse(opts, args);
+ boolean sessionModeCliOption = cliParser.hasOption(TezConstants.TEZ_SESSION_MODE_CLI_OPTION);
+
+ LOG.info("Creating DAGAppMaster for "
+ + "applicationId=" + applicationAttemptId.getApplicationId()
+ + ", attemptNum=" + applicationAttemptId.getAttemptId()
+ + ", AMContainerId=" + containerId
+ + ", jvmPid=" + pid
+ + ", userFromEnv=" + jobUserName
+ + ", cliSessionOption=" + sessionModeCliOption
+ + ", pwd=" + System.getenv(ApplicationConstants.Environment.PWD.name())
+ + ", localDirs=" + System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())
+ + ", logDirs=" + System.getenv(ApplicationConstants.Environment.LOG_DIRS.name()));
+
+ Configuration conf = new Configuration(new YarnConfiguration());
+
+ DAGProtos.ConfigurationProto confProto = TezUtilsInternal.readUserSpecifiedTezConfiguration(
+ System.getenv(ApplicationConstants.Environment.PWD.name()));
+ TezUtilsInternal.addUserSpecifiedTezConfiguration(conf, confProto.getConfKeyValuesList());
+
+ AMPluginDescriptorProto amPluginDescriptorProto = null; // stays null when the proto carries no plugin descriptor
+ if (confProto.hasAmPluginDescriptor()) {
+ amPluginDescriptorProto = confProto.getAmPluginDescriptor();
+ }
+
+ UserGroupInformation.setConfiguration(conf);
+ Credentials credentials = UserGroupInformation.getCurrentUser().getCredentials();
+
+ TezUtilsInternal.setSecurityUtilConfigration(LOG, conf);
+
+ String nodeHostString = System.getenv(ApplicationConstants.Environment.NM_HOST.name());
+ String nodePortString = System.getenv(ApplicationConstants.Environment.NM_PORT.name());
+ String nodeHttpPortString =
+ System.getenv(ApplicationConstants.Environment.NM_HTTP_PORT.name());
+ long appSubmitTime = Long.parseLong(appSubmitTimeStr);
+ RssDAGAppMaster appMaster =
+ new RssDAGAppMaster(applicationAttemptId, containerId, nodeHostString,
+ Integer.parseInt(nodePortString),
+ Integer.parseInt(nodeHttpPortString), new SystemClock(), appSubmitTime,
+ sessionModeCliOption,
+ System.getenv(ApplicationConstants.Environment.PWD.name()),
+ TezCommonUtils.getTrimmedStrings(
+ System.getenv(ApplicationConstants.Environment.LOCAL_DIRS.name())),
+ TezCommonUtils.getTrimmedStrings(System.getenv(ApplicationConstants.Environment.LOG_DIRS.name())),
+ clientVersion, credentials, jobUserName, amPluginDescriptorProto);
+ ShutdownHookManager.get().addShutdownHook(
+ new DAGAppMasterShutdownHook(appMaster), SHUTDOWN_HOOK_PRIORITY);
+ // NOTE(review): both hooks use the same priority, so their relative order is not fixed here - confirm
+ ShutdownHookManager.get().addShutdownHook(new RssDAGAppMasterShutdownHook(appMaster), SHUTDOWN_HOOK_PRIORITY);
+
+
+ // log the system properties
+ if (LOG.isInfoEnabled()) {
+ String systemPropsToLog = TezCommonUtils.getSystemPropertiesToLog(conf);
+ if (systemPropsToLog != null) {
+ LOG.info(systemPropsToLog);
+ }
+ }
+
+ initAndStartRSSClient(appMaster, conf, applicationAttemptId); // RSS client and shuffle manager start first
+ initAndStartAppMaster(appMaster, conf);
+ } catch (Throwable t) {
+ LOG.error("Error starting DAGAppMaster", t);
+ System.exit(1);
+ }
+ }
+
+ /**
+  * Writes {@code extraConf} as XML into the application's Tez staging directory and records the
+  * resulting file on {@code appMaster} as a local resource.
+  *
+  * @param appMaster the application master whose {@code rssConfFileLocalResource} is set
+  * @param conf configuration used to resolve the staging path and its file system
+  * @param extraConf the configuration that is serialized to the file
+  * @param strAppId application id used as the staging sub directory name
+  * @throws RssException if the file cannot be written or its status cannot be read
+  */
+ static void writeExtraConf(final RssDAGAppMaster appMaster, Configuration conf,
+     TezConfiguration extraConf, String strAppId) {
+   try {
+     Path baseStagingPath = TezCommonUtils.getTezBaseStagingPath(conf);
+     Path tezStagingDir = new Path(new Path(baseStagingPath, TEZ_SYSTEM_SUB_DIR), strAppId);
+
+     FileSystem fs = tezStagingDir.getFileSystem(conf);
+     Path rssConfFilePath = new Path(tezStagingDir, RssTezConfig.RSS_CONF_FILE);
+
+     try (FSDataOutputStream out =
+         FileSystem.create(fs, rssConfFilePath,
+             new FsPermission(JobSubmissionFiles.JOB_FILE_PERMISSION))) {
+       extraConf.writeXml(out);
+     }
+     // Size and timestamp of the resource are taken from the file that was just written.
+     FileStatus rsrcStat = fs.getFileStatus(rssConfFilePath);
+
+     appMaster.rssConfFileLocalResource = DAGProtos.PlanLocalResource.newBuilder()
+         // The resource name is a class constant, so it is referenced statically rather than via the instance.
+         .setName(rssConfFileLocalResourceName)
+         .setUri(rsrcStat.getPath().toString())
+         .setSize(rsrcStat.getLen())
+         .setTimeStamp(rsrcStat.getModificationTime())
+         .setType(DAGProtos.PlanLocalResourceType.FILE)
+         .setVisibility(DAGProtos.PlanLocalResourceVisibility.APPLICATION)
+         .build();
+     LOG.info("Upload extra conf success!");
+   } catch (Exception e) {
+     LOG.error("Upload extra conf exception!", e);
+     throw new RssException("Upload extra conf exception ", e);
+   }
+ }
+
+ /**
+  * Disables Tez slow start unless it is explicitly enabled for RSS.
+  *
+  * <p>When {@code RssTezConfig.RSS_AM_SLOW_START_ENABLE} is false, both the minimum and the maximum
+  * source fraction of the shuffle vertex manager are set to 1.0 on {@code conf}.
+  */
+ static void mayCloseTezSlowStart(Configuration conf) {
+   boolean slowStartEnabled =
+       conf.getBoolean(RssTezConfig.RSS_AM_SLOW_START_ENABLE, RssTezConfig.RSS_AM_SLOW_START_ENABLE_DEFAULT);
+   if (slowStartEnabled) {
+     return;
+   }
+   conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, 1.0f);
+   conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, 1.0f);
+ }
+
+ static class RssDAGAppMasterShutdownHook implements Runnable {
+ RssDAGAppMaster appMaster;
+
+ RssDAGAppMasterShutdownHook(RssDAGAppMaster appMaster) {
+ this.appMaster = appMaster;
+ }
+
+ @Override
+ public void run() {
+ if (appMaster.shuffleWriteClient != null) {
+ appMaster.shuffleWriteClient.close();
+ }
+
+ if (appMaster.tezRemoteShuffleManager != null) {
+ try {
+ appMaster.tezRemoteShuffleManager.shutdown();
+ } catch (Exception e) {
+ RssDAGAppMaster.LOG.info("TezRemoteShuffleManager shutdown error: " + e.getMessage());
+ }
+ }
+
+ RssDAGAppMaster.LOG.info("MRAppMaster received a signal. Signaling RMCommunicator and JobHistoryEventHandler.");
+ this.appMaster.stop();
+ }
+ }
+}
diff --git a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
new file mode 100644
index 0000000..2fe9dc3
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.Sets;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.apache.hadoop.ipc.ProtocolSignature;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.ipc.Server;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.authorize.PolicyProvider;
+import org.apache.hadoop.security.token.Token;
+import org.apache.tez.common.GetShuffleServerRequest;
+import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.RssTezUtils;
+import org.apache.tez.common.ServicePluginLifecycle;
+import org.apache.tez.common.ShuffleAssignmentsInfoWritable;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleDataDistributionType;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
+
+import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
+
+public class TezRemoteShuffleManager implements ServicePluginLifecycle {
+ private static final Logger LOG = LoggerFactory.getLogger(TezRemoteShuffleManager.class);
+
+ protected InetSocketAddress address;
+
+ protected volatile Server server;
+ private String tokenIdentifier;
+ private Token<JobTokenIdentifier> sessionToken;
+ private Configuration conf;
+ private TezRemoteShuffleUmbilicalProtocolImpl tezRemoteShuffleUmbilical;
+ private ShuffleWriteClient rssClient;
+ private String appId;
+
+ /**
+  * Creates the manager; the RPC server is not started until {@code start()} is called.
+  *
+  * @param tokenIdentifier identifier stored on the manager
+  * @param sessionToken session token stored on the manager (callers may pass null)
+  * @param conf configuration used for the RPC server and the shuffle assignment settings
+  * @param appId RSS application id used when requesting and registering shuffles
+  * @param rssClient client used to talk to the coordinator and the shuffle servers
+  */
+ public TezRemoteShuffleManager(String tokenIdentifier, Token<JobTokenIdentifier> sessionToken,
+     Configuration conf, String appId, ShuffleWriteClient rssClient) {
+   this.conf = conf;
+   this.appId = appId;
+   this.rssClient = rssClient;
+   this.tokenIdentifier = tokenIdentifier;
+   this.sessionToken = sessionToken;
+   // Protocol implementation that is later handed to the RPC server as its instance.
+   this.tezRemoteShuffleUmbilical = new TezRemoteShuffleUmbilicalProtocolImpl();
+ }
+
+ /** Lifecycle hook; this manager has nothing to prepare before {@code start()}. */
+ @Override
+ public void initialize() throws Exception {
+   // Intentionally empty.
+ }
+
+ /** Starts the manager by bringing up its RPC server. */
+ @Override
+ public void start() throws Exception {
+   this.startRpcServer();
+ }
+
+ @Override
+ public void shutdown() throws Exception {
+ server.stop();
+ }
+
+ /**
+  * RPC endpoint that hands out shuffle assignments to tasks.
+  *
+  * <p>Assignments are requested once per shuffle id through {@code getShuffleWorks} and then served
+  * from an in-memory map. Response status is 0 on success, -1 for a null request or a null
+  * assignment, and -2 when fetching the assignment throws.
+  */
+ private class TezRemoteShuffleUmbilicalProtocolImpl implements TezRemoteShuffleUmbilicalProtocol {
+   private final Map<Integer, ShuffleAssignmentsInfo> shuffleIdToShuffleAssignsInfo = new HashMap<>();
+
+   @Override
+   public long getProtocolVersion(String s, long l) throws IOException {
+     return versionID;
+   }
+
+   @Override
+   public ProtocolSignature getProtocolSignature(String protocol, long clientVersion,
+       int clientMethodsHash) throws IOException {
+     return ProtocolSignature.getProtocolSignature(this, protocol,
+         clientVersion, clientMethodsHash);
+   }
+
+   /**
+    * Returns the shuffle assignment for the shuffle id of {@code request}.
+    *
+    * @param request the request carrying the shuffle id and the partition number; may be null
+    * @return a response whose status and message describe the outcome; never null
+    */
+   @Override
+   public GetShuffleServerResponse getShuffleAssignments(GetShuffleServerRequest request)
+       throws IOException, TezException {
+
+     GetShuffleServerResponse response = new GetShuffleServerResponse();
+     if (request != null) {
+       LOG.info("getShuffleAssignments with request = " + request);
+     } else {
+       LOG.error("getShuffleAssignments with request is null");
+       response.setStatus(-1);
+       response.setRetMsg("GetShuffleServerRequest is null");
+       return response;
+     }
+
+     int shuffleId = request.getShuffleId();
+     try {
+       // The lock is held while fetching so that concurrent requests for one shuffle id register it once.
+       synchronized (TezRemoteShuffleUmbilicalProtocolImpl.class) {
+         ShuffleAssignmentsInfo shuffleAssignmentsInfo = shuffleIdToShuffleAssignsInfo.get(shuffleId);
+         if (shuffleAssignmentsInfo == null) {
+           shuffleAssignmentsInfo = getShuffleWorks(request.getPartitionNum(), shuffleId);
+         }
+
+         if (shuffleAssignmentsInfo == null) {
+           response.setStatus(-1);
+           response.setRetMsg("shuffleAssignmentsInfo is null");
+         } else {
+           response.setStatus(0);
+           response.setRetMsg("");
+           response.setShuffleAssignmentsInfoWritable(new ShuffleAssignmentsInfoWritable(shuffleAssignmentsInfo));
+           // Cache only after the response was built successfully; putIfAbsent skips the rewrite on cache hits.
+           shuffleIdToShuffleAssignsInfo.putIfAbsent(shuffleId, shuffleAssignmentsInfo);
+         }
+       }
+     } catch (Exception rssException) {
+       // Log the failure here: the caller only receives the status code and the message text.
+       LOG.error("Failed to get shuffle assignments for shuffleId {}", shuffleId, rssException);
+       response.setStatus(-2);
+       String message = rssException.getMessage();
+       response.setRetMsg(message != null ? message : rssException.toString());
+     }
+
+     return response;
+   }
+ }
+
+ /**
+  * Requests shuffle servers for one shuffle from the coordinator and registers the shuffle on each
+  * assigned server.
+  *
+  * @param partitionNum number of partitions of the shuffle
+  * @param shuffleId id of the shuffle
+  * @return the assignment, or null when the coordinator returned no server to partition ranges
+  * @throws RssException if requesting the assignment or registering the shuffle fails after the retries
+  */
+ private ShuffleAssignmentsInfo getShuffleWorks(int partitionNum, int shuffleId) {
+   ShuffleAssignmentsInfo shuffleAssignmentsInfo;
+   // NOTE(review): 200 is a hard-coded argument of getRequiredShuffleServerNumber - confirm its meaning there
+   int requiredAssignmentShuffleServersNum = RssTezUtils.getRequiredShuffleServerNumber(conf, 200, partitionNum);
+   // retryInterval must be bigger than `rss.server.heartbeat.timeout`, otherwise a retry may return the same result
+   long retryInterval = conf.getLong(RssTezConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL,
+       RssTezConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+   int retryTimes = conf.getInt(RssTezConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES,
+       RssTezConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+
+   // Get the configured server assignment tags; the default shuffle version tag is always added.
+   Set<String> assignmentTags = new HashSet<>();
+   String rawTags = conf.get(RssTezConfig.RSS_CLIENT_ASSIGNMENT_TAGS, "");
+   if (StringUtils.isNotEmpty(rawTags)) {
+     // Trim every tag and drop empty ones so that "a, b" yields "a" and "b" rather than "a" and " b".
+     Arrays.stream(rawTags.split(","))
+         .map(String::trim)
+         .filter(StringUtils::isNotEmpty)
+         .forEach(assignmentTags::add);
+   }
+   assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
+
+   // get remote storage from coordinator if necessary
+   boolean dynamicConfEnabled = conf.getBoolean(RssTezConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED,
+       RssTezConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE);
+   RemoteStorageInfo defaultRemoteStorage =
+       new RemoteStorageInfo(conf.get(RssTezConfig.RSS_REMOTE_STORAGE_PATH, ""));
+   String storageType = conf.get(RssTezConfig.RSS_STORAGE_TYPE);
+   boolean testMode = conf.getBoolean(RssTezConfig.RSS_TEST_MODE_ENABLE, false);
+   ClientUtils.validateTestModeConf(testMode, storageType);
+   RemoteStorageInfo remoteStorage = ClientUtils.fetchRemoteStorage(
+       appId, defaultRemoteStorage, dynamicConfEnabled, storageType, rssClient);
+
+   try {
+     shuffleAssignmentsInfo = RetryUtils.retry(() -> {
+       ShuffleAssignmentsInfo shuffleAssignments =
+           rssClient.getShuffleAssignments(
+               appId,
+               shuffleId,
+               partitionNum,
+               1,
+               Sets.newHashSet(assignmentTags),
+               requiredAssignmentShuffleServersNum,
+               -1
+           );
+
+       Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+           shuffleAssignments.getServerToPartitionRanges();
+
+       if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+         return null;
+       }
+       LOG.info("Start to register shuffle");
+       long start = System.currentTimeMillis();
+       // The setting is the same for every server, so the configuration is converted once per attempt.
+       int maxConcurrencyPerPartitionToWrite =
+           RssTezConfig.toRssConf(conf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+       for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> entry : serverToPartitionRanges.entrySet()) {
+         rssClient.registerShuffle(
+             entry.getKey(),
+             appId,
+             shuffleId,
+             entry.getValue(),
+             remoteStorage,
+             ShuffleDataDistributionType.NORMAL,
+             maxConcurrencyPerPartitionToWrite
+         );
+       }
+       LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
+       return shuffleAssignments;
+     }, retryInterval, retryTimes);
+   } catch (Throwable throwable) {
+     LOG.error("registerShuffle failed!", throwable);
+     throw new RssException("registerShuffle failed!", throwable);
+   }
+
+   return shuffleAssignmentsInfo;
+ }
+
+ protected void startRpcServer() {
+ try {
+ String rssAmRpcBindAddress;
+ Integer rssAmRpcBindPort;
+ if (conf.getBoolean(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_DEBUG, false)) {
+ rssAmRpcBindAddress = conf.get(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_ADDRESS, "0.0.0.0");
+ rssAmRpcBindPort = conf.getInt(RssTezConfig.RSS_AM_SHUFFLE_MANAGER_PORT, 0);
+ } else {
+ rssAmRpcBindAddress = "0.0.0.0";
+ rssAmRpcBindPort = 0;
+ }
+
+ server = new RPC.Builder(conf)
+ .setProtocol(TezRemoteShuffleUmbilicalProtocol.class)
+ .setBindAddress(rssAmRpcBindAddress)
+ .setPort(rssAmRpcBindPort)
+ .setInstance(tezRemoteShuffleUmbilical)
+ .setNumHandlers(
+ conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
+ TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
+ .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE)
+ .build();
+
+ // Enable service authorization?
+ if (conf.getBoolean(
+ CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
+ false)) {
+ refreshServiceAcls(conf, new TezAMPolicyProvider());
+ }
+
+ server.start();
+ InetSocketAddress serverBindAddress = NetUtils.getConnectAddress(server);
+ this.address = NetUtils.createSocketAddrForHost(
+ serverBindAddress.getAddress().getCanonicalHostName(),
+ serverBindAddress.getPort());
+ LOG.info("Instantiated TezRemoteShuffleManager RPC at " + this.address);
+ } catch (IOException e) {
+ throw new TezUncheckedException(e);
+ }
+ }
+
+ /** Applies the service ACLs described by {@code policyProvider} to the RPC server. */
+ private void refreshServiceAcls(Configuration configuration, PolicyProvider policyProvider) {
+   server.refreshServiceAcl(configuration, policyProvider);
+ }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
index 34419bf..57bcb1a 100644
--- a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
@@ -49,9 +49,9 @@
partitionToServers.put(4, new ArrayList<>());
ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
- ShuffleServerInfo work2 = new ShuffleServerInfo("host1", 9999);
- ShuffleServerInfo work3 = new ShuffleServerInfo("host1", 9999);
- ShuffleServerInfo work4 = new ShuffleServerInfo("host1", 9999);
+ ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+ ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+ ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
partitionToServers.get(0).addAll(Arrays.asList(work1, work2, work3, work4));
partitionToServers.get(1).addAll(Arrays.asList(work1, work2, work3, work4));
diff --git a/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
new file mode 100644
index 0000000..997ad78
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/RssDAGAppMasterTest.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.tez.dag.api.TezException;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RssDAGAppMasterTest {
+
+ @Test
+ public void testAddAdditionalResource() throws TezException {
+ DAGProtos.DAGPlan dagPlan = DAGProtos.DAGPlan.getDefaultInstance();
+ List<DAGProtos.PlanLocalResource> originalResources = dagPlan.getLocalResourceList();
+ if (originalResources == null) {
+ originalResources = new ArrayList<>();
+ } else {
+ originalResources = new ArrayList<>(originalResources);
+ }
+
+ DAGProtos.PlanLocalResource additionalResource = DAGProtos.PlanLocalResource.newBuilder()
+ .setName("rss_conf.xml")
+ .setUri("/data1/test")
+ .setSize(12)
+ .setTimeStamp(System.currentTimeMillis())
+ .setType(DAGProtos.PlanLocalResourceType.FILE)
+ .setVisibility(DAGProtos.PlanLocalResourceVisibility.APPLICATION)
+ .build();
+
+ RssDAGAppMaster.addAdditionalResource(dagPlan, additionalResource);
+ List<DAGProtos.PlanLocalResource> newResources = dagPlan.getLocalResourceList();
+
+ originalResources.add(additionalResource);
+
+ assertEquals(originalResources.size(), newResources.size());
+ for (int i = 0; i < originalResources.size(); i++) {
+ assertEquals(originalResources.get(i), newResources.get(i));
+ }
+ }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
new file mode 100644
index 0000000..8d304e4
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app;
+
+import java.net.InetSocketAddress;
+import java.security.PrivilegedExceptionAction;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.GetShuffleServerRequest;
+import org.apache.tez.common.GetShuffleServerResponse;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TezRemoteShuffleManagerTest {
+
+  /**
+   * Starts a manager backed by a mocked {@link ShuffleWriteClient}, requests shuffle
+   * assignments through a {@link TezRemoteShuffleUmbilicalProtocol} proxy and checks the
+   * response status and the number of partitions returned.
+   */
+  @Test
+  public void testTezRemoteShuffleManager() {
+    try {
+      ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+      ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+      ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+      ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
+
+      // Partition id -> servers assigned to that partition.
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+      partitionToServers.put(0, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+      partitionToServers.put(1, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+      partitionToServers.put(2, new ArrayList<>(Arrays.asList(work1, work3)));
+      partitionToServers.put(3, new ArrayList<>(Arrays.asList(work3, work4)));
+      partitionToServers.put(4, new ArrayList<>(Arrays.asList(work2, work4)));
+
+      // Inverse view of the map above: server -> single-partition ranges assigned to it.
+      PartitionRange range0 = new PartitionRange(0, 0);
+      PartitionRange range1 = new PartitionRange(1, 1);
+      PartitionRange range2 = new PartitionRange(2, 2);
+      PartitionRange range3 = new PartitionRange(3, 3);
+      PartitionRange range4 = new PartitionRange(4, 4);
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+      serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+      serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+      serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, range3));
+      serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, range4));
+
+      ShuffleAssignmentsInfo shuffleAssignmentsInfo = new ShuffleAssignmentsInfo(partitionToServers,
+          serverToPartitionRanges);
+
+      // The mocked client hands back the fixture for any matching getShuffleAssignments call.
+      ShuffleWriteClient client = mock(ShuffleWriteClient.class);
+      when(client.getShuffleAssignments(anyString(), anyInt(), anyInt(), anyInt(), anySet(), anyInt(), anyInt()))
+          .thenReturn(shuffleAssignmentsInfo);
+
+      ApplicationId appId = ApplicationId.newInstance(9999, 72);
+      Configuration conf = new Configuration();
+      TezRemoteShuffleManager tezRemoteShuffleManager = new TezRemoteShuffleManager(appId.toString(), null,
+          conf, appId.toString(), client);
+      tezRemoteShuffleManager.initialize();
+      tezRemoteShuffleManager.start();
+      // NOTE(review): the manager is never stopped here; confirm whether it offers a stop hook.
+
+      String host = tezRemoteShuffleManager.address.getHostString();
+      int port = tezRemoteShuffleManager.address.getPort();
+      final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, port);
+
+      String tokenIdentifier = appId.toString();
+      UserGroupInformation taskOwner = UserGroupInformation.createRemoteUser(tokenIdentifier);
+
+      TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner.doAs(
+          new PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
+            @Override
+            public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
+              return RPC.getProxy(TezRemoteShuffleUmbilicalProtocol.class,
+                  TezRemoteShuffleUmbilicalProtocol.versionID, address, conf);
+            }
+          });
+      try {
+        TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+        TezVertexID vId = TezVertexID.getInstance(dagId, 35);
+        TezTaskID tId = TezTaskID.getInstance(vId, 389);
+        TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, 2);
+        int mapNum = 1;
+        int shuffleId = 10001;
+        int reduceNum = shuffleAssignmentsInfo.getPartitionToServers().size();
+
+        GetShuffleServerRequest request = new GetShuffleServerRequest(taId, mapNum, reduceNum, shuffleId);
+        GetShuffleServerResponse response = umbilical.getShuffleAssignments(request);
+        assertEquals(0, response.getStatus(), "failed to get Shuffle Assignments");
+        assertEquals(reduceNum, response.getShuffleAssignmentsInfoWritable().getShuffleAssignmentsInfo()
+            .getPartitionToServers().size());
+      } finally {
+        RPC.stopProxy(umbilical); // release the client connection even when an assertion fails
+      }
+    } catch (Exception e) {
+      // Report the real exception as the failure cause instead of a misleading message comparison.
+      fail("Unexpected exception while requesting shuffle assignments", e);
+    }
+  }
+}