[#854][FOLLOWUP] feat(tez): add RssTezFetcher to fetch data from worker. (#920)
### What changes were proposed in this pull request?
Add `RssTezFetcher` to fetch shuffle data from the workers.
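A minimal sketch of how the new fetcher is expected to be driven (the surrounding `fetcherCallback`, `inputManager`, and `shuffleReadClient` objects are illustrative placeholders created by the caller; only the `RssTezFetcher` API comes from this patch):

```java
// construct the fetcher for one partition of this reduce task
RssTezFetcher fetcher = new RssTezFetcher(
    fetcherCallback,             // notified through fetchSucceeded() for every fetched block
    inputManager,                // allocates in-memory or on-disk space for each block
    shuffleReadClient,           // reads shuffle blocks from the Uniffle shuffle servers
    rssSuccessBlockIdBitmapMap,  // bitmaps of successfully written block ids
    partitionId,                 // the partition this task reads
    new RssConf());
// loops until all blocks for the partition have been read and verified
fetcher.fetchAllRssBlocks();
```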
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/854
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests: the new `RssTezFetcherTest`.
diff --git a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
index 5b1a06f..152a013 100644
--- a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
@@ -68,10 +68,10 @@
* @throws TezException
*/
private static Map<Integer, List<ShuffleServerInfo>> doRequestShuffleServer(
- ApplicationId applicationId,
- Configuration conf,
- TezTaskAttemptID taskAttemptId,
- int shuffleId) throws IOException, InterruptedException, TezException {
+ ApplicationId applicationId,
+ Configuration conf,
+ TezTaskAttemptID taskAttemptId,
+ int shuffleId) throws IOException, InterruptedException, TezException {
UserGroupInformation taskOwner = UserGroupInformation.createRemoteUser(applicationId.toString());
Pair<String, Integer> amHostPort = getAmHostPort();
@@ -97,10 +97,10 @@
}
public static Map<Integer, List<ShuffleServerInfo>> requestShuffleServer(
- ApplicationId applicationId,
- Configuration conf,
- TezTaskAttemptID taskAttemptId,
- int shuffleId) {
+ ApplicationId applicationId,
+ Configuration conf,
+ TezTaskAttemptID taskAttemptId,
+ int shuffleId) {
try {
return doRequestShuffleServer(applicationId, conf, taskAttemptId,shuffleId);
} catch (IOException | InterruptedException | TezException e) {
diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
new file mode 100644
index 0000000..d7ff1b9
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.FetcherCallback;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RssTezBypassWriter;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+public class RssTezFetcher {
+ private static final Logger LOG = LoggerFactory.getLogger(RssTezFetcher.class);
+ private final FetcherCallback fetcherCallback;
+
+ private final FetchedInputAllocator inputManager;
+
+ private long copyBlockCount = 0;
+
+ private volatile boolean stopped = false;
+
+ private final ShuffleReadClient shuffleReadClient;
+ Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap;
+ private int partitionId; // partition id to fetch for this task
+ private long readTime = 0;
+ private long decompressTime = 0;
+ private long serializeTime = 0;
+ private long waitTime = 0;
+ private long copyTime = 0; // the sum of readTime + decompressTime + serializeTime + waitTime
+ private long unCompressionLength = 0;
+ private int uniqueMapId = 0;
+
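+  // state carried across copyFromRssServer() calls when a merge reservation
+  // fails and the uncompressed block has to wait for a retry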
+ private boolean hasPendingData = false;
+ private long startWait;
+ private int waitCount = 0;
+ private byte[] uncompressedData = null;
+ private Codec codec;
+
+ RssTezFetcher(FetcherCallback fetcherCallback,
+ FetchedInputAllocator inputManager,
+ ShuffleReadClient shuffleReadClient,
+ Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
+ int partitionId, RssConf rssConf) {
+ this.fetcherCallback = fetcherCallback;
+ this.inputManager = inputManager;
+ this.shuffleReadClient = shuffleReadClient;
+ this.rssSuccessBlockIdBitmapMap = rssSuccessBlockIdBitmapMap;
+ this.partitionId = partitionId;
+ this.codec = Codec.newInstance(rssConf);
+ }
+
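+  /** Drives the fetch loop until the read client reports that no blocks remain. */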
+ public void fetchAllRssBlocks() throws IOException {
+ while (!stopped) {
+ try {
+ copyFromRssServer();
+ } catch (Exception e) {
+ LOG.error("Failed to fetchAllRssBlocks.", e);
+ throw e;
+ }
+ }
+ }
+
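+  /**
+   * Fetch and process a single block in three phases: read a compressed block
+   * from the shuffle server, decompress it, and hand it to the allocator and
+   * merger. If the allocator cannot reserve space, the block is kept pending
+   * and retried on the next call; a null block from the read client means all
+   * data has been consumed and fetching stops.
+   */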
+ @VisibleForTesting
+ public void copyFromRssServer() throws IOException {
+ CompressedShuffleBlock compressedBlock = null;
+ ByteBuffer compressedData = null;
+ long blockStartFetch = 0;
+ // fetch a block
+ if (!hasPendingData) {
+      final long startFetch = System.currentTimeMillis();
+      blockStartFetch = startFetch;
+ compressedBlock = shuffleReadClient.readShuffleBlockData();
+ if (compressedBlock != null) {
+ compressedData = compressedBlock.getByteBuffer();
+ }
+ long fetchDuration = System.currentTimeMillis() - startFetch;
+ readTime += fetchDuration;
+ }
+
+ // uncompress the block
+ if (!hasPendingData && compressedData != null) {
+ final long startDecompress = System.currentTimeMillis();
+ int uncompressedLen = compressedBlock.getUncompressLength();
+ ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+ codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+ uncompressedData = decompressedBuffer.array();
+      unCompressionLength += uncompressedLen;
+ long decompressDuration = System.currentTimeMillis() - startDecompress;
+ decompressTime += decompressDuration;
+ }
+
+ if (uncompressedData != null) {
+ // start to merge
+ final long startSerialization = System.currentTimeMillis();
+ int compressedDataLength = compressedData != null ? compressedData.capacity() : 0;
+ if (issueMapOutputMerge(compressedDataLength, blockStartFetch)) {
+ long serializationDuration = System.currentTimeMillis() - startSerialization;
+ serializeTime += serializationDuration;
+        // if the reserve succeeded, reset the pending status for the next fetch
+ if (hasPendingData) {
+ waitTime += System.currentTimeMillis() - startWait;
+ }
+ hasPendingData = false;
+ uncompressedData = null;
+      } else {
+        LOG.info("Failed to reserve space for the uncompressed block, will wait and retry");
+        // if the reserve failed, keep the block pending, then return and wait
+        hasPendingData = true;
+        waitCount++;
+        startWait = System.currentTimeMillis();
+        return;
+      }
+
+ // update some status
+ copyBlockCount++;
+ copyTime = readTime + decompressTime + serializeTime + waitTime;
+ } else {
+      // no more blocks to read: close the reader and check data consistency
+ shuffleReadClient.close();
+ shuffleReadClient.checkProcessedBlockIds();
+ shuffleReadClient.logStatics();
+ LOG.info("reduce task partition:" + partitionId + " read block cnt: " + copyBlockCount + " cost "
+ + readTime + " ms to fetch and " + decompressTime + " ms to decompress with unCompressionLength["
+ + unCompressionLength + "] and " + serializeTime + " ms to serialize and "
+ + waitTime + " ms to wait resource, total copy time: " + copyTime);
+ LOG.info("Stop fetcher");
+ stopFetch();
+ }
+ }
+
+ private boolean issueMapOutputMerge(int compressedLength, long blockStartFetch) throws IOException {
+    // Allocate a MapOutput (either in-memory or on-disk) to hold the uncompressed block.
+    // In RSS, one map output is sent as multiple blocks, so the reducer needs to
+    // treat each block as a fake "map output".
+    // To avoid name conflicts, we use getNextUniqueInputAttemptIdentifier instead,
+    // which generates a unique InputAttemptIdentifier(increased_seq++, 0).
+ InputAttemptIdentifier uniqueInputAttemptIdentifier = getNextUniqueInputAttemptIdentifier();
+    // allocate() may throw an IOException, which fails this fetch and kills the reduce attempt
+    FetchedInput fetchedInput = inputManager.allocate(uncompressedData.length,
+        compressedLength, uniqueInputAttemptIdentifier);
+
+    // space has been allocated; now write the data into the fetched input
+ try {
+ RssTezBypassWriter.write(fetchedInput, uncompressedData);
+      // let the merger know this block is ready for merging
+      fetcherCallback.fetchSucceeded(null, uniqueInputAttemptIdentifier, fetchedInput,
+          compressedLength, uncompressedData.length, System.currentTimeMillis() - blockStartFetch);
+ } catch (Throwable t) {
+ LOG.error("Failed to write fetchedInput.", t);
+ throw new RssException("Partition: " + partitionId + " cannot write block to "
+ + fetchedInput.getClass().getSimpleName() + " due to: " + t.getClass().getName());
+ }
+ return true;
+ }
+
+ private InputAttemptIdentifier getNextUniqueInputAttemptIdentifier() {
+ return new InputAttemptIdentifier(uniqueMapId++, 0);
+ }
+
+ private void stopFetch() {
+ stopped = true;
+ }
+
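+  /** Number of times the fetcher had to wait because a merge reservation failed. */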
+ @VisibleForTesting
+ public int getRetryCount() {
+ return waitCount;
+ }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java
new file mode 100644
index 0000000..4d00a5a
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.serializer.Deserializer;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetcherCallback;
+import org.apache.tez.runtime.library.common.shuffle.MemoryFetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryReader;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryWriter;
+import org.apache.tez.runtime.library.common.sort.impl.IFile;
+import org.apache.tez.runtime.library.utils.BufferUtils;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class RssTezFetcherTest {
+ private static final Logger LOG = LoggerFactory.getLogger(RssTezFetcherTest.class);
+
+ static Configuration conf = new Configuration();
+
+ static FileSystem fs;
+ static List<byte[]> data;
+ static List<KVPair> textData;
+ static Codec codec = new Lz4Codec();
+
+ @Test
+ public void writeAndReadDataTestWithoutRss() throws Throwable {
+ fs = FileSystem.getLocal(conf);
+ initRssData();
+ ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
+
+ SimpleFetchedInputAllocator inputManager = new SimpleFetchedInputAllocator(
+ TezUtilsInternal.cleanVertexName("Map 1"),
+ "1",
+ 2, conf,
+ 200 * 1024 * 1024L,
+ 300 * 1024 * 1024L);
+
+ List<byte[]> result = new ArrayList<byte[]>();
+ FetcherCallback fetcherCallback = new FetcherCallback() {
+ @Override
+ public void fetchSucceeded(String host, InputAttemptIdentifier srcAttemptIdentifier, FetchedInput fetchedInput,
+ long fetchedBytes, long decompressedLength, long copyDuration) throws IOException {
+ LOG.info("Fetch success");
+ fetchedInput.commit();
+ result.add(((MemoryFetchedInput) fetchedInput).getBytes());
+ }
+
+ @Override
+ public void fetchFailed(String host, InputAttemptIdentifier srcAttemptIdentifier, boolean connectFailed) {
+ fail();
+ }
+ };
+
+ RssTezFetcher rssFetcher = new RssTezFetcher(fetcherCallback,
+ inputManager,
+ shuffleReadClient,
+ null,
+ 2,
+ new RssConf());
+ rssFetcher.fetchAllRssBlocks();
+
+ for (int i = 0; i < data.size(); i++) {
+ Text readKey = new Text();
+ IntWritable readValue = new IntWritable();
+ Deserializer<Text> keyDeserializer;
+ Deserializer<IntWritable> valDeserializer;
+ SerializationFactory serializationFactory = new SerializationFactory(conf);
+ keyDeserializer = serializationFactory.getDeserializer(Text.class);
+ valDeserializer = serializationFactory.getDeserializer(IntWritable.class);
+ DataInputBuffer keyIn = new DataInputBuffer();
+ DataInputBuffer valIn = new DataInputBuffer();
+ keyDeserializer.open(keyIn);
+ valDeserializer.open(valIn);
+
+ InMemoryReader reader = new InMemoryReader(null,
+ new InputAttemptIdentifier(0, 0), result.get(i), 0, result.get(i).length);
+ int numRecordsRead = 0;
+ while (reader.nextRawKey(keyIn)) {
+ reader.nextRawValue(valIn);
+ readKey = keyDeserializer.deserialize(readKey);
+ readValue = valDeserializer.deserialize(readValue);
+
+ KVPair expected = textData.get(numRecordsRead);
+ assertEquals(expected.getKey(), readKey);
+        assertEquals(expected.getValue(), readValue);
+
+ numRecordsRead++;
+ }
+ assertEquals(textData.size(), numRecordsRead);
+ LOG.info("Found: " + numRecordsRead + " records");
+ }
+ }
+
+ private static void initRssData() throws Exception {
+ InMemoryWriter writer = null;
+ BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1024 * 1024);
+ List<KVPair> pairData = generateTestData(true, 10);
+    // no RLE, no repeated keys, no compression
+ writer = new InMemoryWriter(bout);
+ writeTestFileUsingDataBuffer(writer, false, pairData);
+ data = new ArrayList<>();
+ data.add(bout.getBuffer());
+ textData = new ArrayList<>();
+ textData.addAll(pairData);
+ }
+
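+  /**
+   * A stub read client that serves the pre-compressed blocks in order and then
+   * returns null, which signals the fetcher to stop.
+   */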
+ static class MockedShuffleReadClient implements ShuffleReadClient {
+ List<CompressedShuffleBlock> blocks;
+ int index = 0;
+
+ MockedShuffleReadClient(List<byte[]> data) {
+ this.blocks = new LinkedList<>();
+ data.forEach(bytes -> {
+ byte[] compressed = codec.compress(bytes);
+ blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
+ });
+ }
+
+ @Override
+ public CompressedShuffleBlock readShuffleBlockData() {
+ if (index < blocks.size()) {
+ return blocks.get(index++);
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public void checkProcessedBlockIds() {
+ }
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public void logStatics() {
+ }
+ }
+
+  /**
+   * Generate key value pairs.
+   *
+   * @param sorted whether data should be sorted by key
+   * @param repeatCount number of keys to be repeated
+   * @return the generated list of key value pairs
+   */
+ public static List<KVPair> generateTestData(boolean sorted, int repeatCount) {
+ List<KVPair> data = new LinkedList<KVPair>();
+ Random rnd = new Random();
+ KVPair kvp = null;
+ for (int i = 0; i < 5; i++) {
+ String keyStr = (sorted) ? ("key" + i) : (rnd.nextLong() + "key" + i);
+ Text key = new Text(keyStr);
+ IntWritable value = new IntWritable(i + repeatCount);
+ kvp = new KVPair(key, value);
+ data.add(kvp);
+      if ((repeatCount > 0) && (i % 2 == 0)) { // repeat this key a random number of times
+ int count = rnd.nextInt(5);
+ for (int j = 0; j < count; j++) {
+ repeatCount++;
+ value.set(i + rnd.nextInt());
+ kvp = new KVPair(key, value);
+ data.add(kvp);
+ }
+ }
+ }
+    // If we need to generate repeated keys, also append some repeated keys to the end of the file.
+ if (repeatCount > 0 && kvp != null) {
+ data.add(kvp);
+ data.add(kvp);
+ }
+ return data;
+ }
+
+ private static IFile.Writer writeTestFileUsingDataBuffer(IFile.Writer writer, boolean repeatKeys,
+ List<KVPair> data) throws IOException {
+ DataInputBuffer previousKey = new DataInputBuffer();
+ DataInputBuffer key = new DataInputBuffer();
+ DataInputBuffer value = new DataInputBuffer();
+ for (KVPair kvp : data) {
+ populateData(kvp, key, value);
+
+ if (repeatKeys && (previousKey != null && BufferUtils.compare(key, previousKey) == 0)) {
+ writer.append(org.apache.tez.runtime.library.common.sort.impl.IFile.REPEAT_KEY, value);
+ } else {
+ writer.append(key, value);
+ }
+ previousKey.reset(key.getData(), 0, key.getLength());
+ }
+ writer.close();
+ LOG.info("Uncompressed: " + writer.getRawLength());
+ LOG.info("CompressedSize: " + writer.getCompressedLength());
+ return writer;
+ }
+
+ private static void populateData(KVPair kvp, DataInputBuffer key, DataInputBuffer value)
+ throws IOException {
+ DataOutputBuffer k = new DataOutputBuffer();
+ DataOutputBuffer v = new DataOutputBuffer();
+ kvp.getKey().write(k);
+    kvp.getValue().write(v);
+ key.reset(k.getData(), 0, k.getLength());
+ value.reset(v.getData(), 0, v.getLength());
+ }
+
+ public static class KVPair {
+ private Text key;
+ private IntWritable value;
+
+ public KVPair(Text key, IntWritable value) {
+ this.key = key;
+ this.value = value;
+ }
+
+ public Text getKey() {
+ return this.key;
+ }
+
+    public IntWritable getValue() {
+ return this.value;
+ }
+ }
+}