[#854][FOLLOWUP] feat(tez): add RssTezFetcher to fetch data from worker. (#920)

### What changes were proposed in this pull request?

Add `RssTezFetcher`, which fetches shuffle data for a partition from the shuffle servers (workers), decompresses each block, and hands it to Tez's `FetchedInputAllocator`/`FetcherCallback` for merging.
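
For reviewers, a minimal sketch of how the new fetcher is driven, mirroring the unit test below (`conf`, `fetcherCallback`, and `shuffleReadClient` are assumed to be set up as in `RssTezFetcherTest`):

```java
// Allocate space for fetched inputs, then drain every block of one partition.
FetchedInputAllocator inputManager = new SimpleFetchedInputAllocator(
    TezUtilsInternal.cleanVertexName("Map 1"), "1", 2, conf,
    200 * 1024 * 1024L, 300 * 1024 * 1024L);

RssTezFetcher fetcher = new RssTezFetcher(
    fetcherCallback,    // fetchSucceeded() is called once per fetched block
    inputManager,
    shuffleReadClient,  // reads blocks for this partition from the shuffle servers
    null,               // success-block-id bitmaps (not used by the test)
    2,                  // partition id this task reads
    new RssConf());

// Loops until readShuffleBlockData() returns null, then closes the read client
// and verifies that all block ids were processed.
fetcher.fetchAllRssBlocks();
```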

### Why are the changes needed?

Fix: https://github.com/apache/incubator-uniffle/issues/854

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?
Unit test: `RssTezFetcherTest`, which round-trips IFile data through a mocked `ShuffleReadClient`.
diff --git a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
index 5b1a06f..152a013 100644
--- a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
@@ -68,10 +68,10 @@
    * @throws TezException
    */
   private static Map<Integer, List<ShuffleServerInfo>> doRequestShuffleServer(
-      ApplicationId applicationId,
-      Configuration conf,
-      TezTaskAttemptID taskAttemptId,
-      int shuffleId) throws IOException, InterruptedException, TezException {
+            ApplicationId applicationId,
+            Configuration conf,
+            TezTaskAttemptID taskAttemptId,
+            int shuffleId) throws IOException, InterruptedException, TezException {
     UserGroupInformation taskOwner = UserGroupInformation.createRemoteUser(applicationId.toString());
 
     Pair<String, Integer> amHostPort = getAmHostPort();
@@ -97,10 +97,10 @@
   }
 
   public static Map<Integer, List<ShuffleServerInfo>> requestShuffleServer(
-      ApplicationId applicationId,
-      Configuration conf,
-      TezTaskAttemptID taskAttemptId,
-      int shuffleId) {
+          ApplicationId applicationId,
+          Configuration conf,
+          TezTaskAttemptID taskAttemptId,
+          int shuffleId) {
     try {
       return doRequestShuffleServer(applicationId, conf, taskAttemptId,shuffleId);
     } catch (IOException | InterruptedException | TezException e) {
diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
new file mode 100644
index 0000000..d7ff1b9
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcher.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInputAllocator;
+import org.apache.tez.runtime.library.common.shuffle.FetcherCallback;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RssTezBypassWriter;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+public class RssTezFetcher {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezFetcher.class);
+  private final FetcherCallback fetcherCallback;
+
+  private final FetchedInputAllocator inputManager;
+
+  private long copyBlockCount = 0;
+
+  private volatile boolean stopped = false;
+
+  private final ShuffleReadClient shuffleReadClient;
+  Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap;
+  private int partitionId;   // partition id to fetch for this task
+  private long readTime = 0;
+  private long decompressTime = 0;
+  private long serializeTime = 0;
+  private long waitTime = 0;
+  private long copyTime = 0;  // the sum of readTime + decompressTime + serializeTime + waitTime
+  private long unCompressionLength = 0;
+  private int uniqueMapId = 0;
+
+  private boolean hasPendingData = false;
+  private long startWait;
+  private int waitCount = 0;
+  private byte[] uncompressedData = null;
+  private Codec codec;
+
+  RssTezFetcher(FetcherCallback fetcherCallback,
+                FetchedInputAllocator inputManager,
+                ShuffleReadClient shuffleReadClient,
+                Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
+                int partitionId, RssConf rssConf) {
+    this.fetcherCallback = fetcherCallback;
+    this.inputManager = inputManager;
+    this.shuffleReadClient = shuffleReadClient;
+    this.rssSuccessBlockIdBitmapMap = rssSuccessBlockIdBitmapMap;
+    this.partitionId = partitionId;
+    this.codec = Codec.newInstance(rssConf);
+  }
+
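+  /**
+   * Drains all blocks for this partition: copyFromRssServer() is invoked in a
+   * loop until no more blocks are available and stopFetch() is called.
+   */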
+  public void fetchAllRssBlocks() throws IOException {
+    while (!stopped) {
+      try {
+        copyFromRssServer();
+      } catch (Exception e) {
+        LOG.error("Failed to fetchAllRssBlocks.", e);
+        throw e;
+      }
+    }
+  }
+
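+  // One fetch step: read one compressed block, decompress it, and offer it to the
+  // allocator/merger. When hasPendingData is set, the read and decompress phases
+  // are skipped and the pending uncompressed block is offered again.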
+  @VisibleForTesting
+  public void copyFromRssServer() throws IOException {
+    CompressedShuffleBlock compressedBlock = null;
+    ByteBuffer compressedData = null;
+    long blockStartFetch = 0;
+    // fetch a block
+    if (!hasPendingData) {
+      final long startFetch = System.currentTimeMillis();
+      blockStartFetch = System.currentTimeMillis();
+      compressedBlock = shuffleReadClient.readShuffleBlockData();
+      if (compressedBlock != null) {
+        compressedData = compressedBlock.getByteBuffer();
+      }
+      long fetchDuration = System.currentTimeMillis() - startFetch;
+      readTime += fetchDuration;
+    }
+
+    // uncompress the block
+    if (!hasPendingData && compressedData != null) {
+      final long startDecompress = System.currentTimeMillis();
+      int uncompressedLen = compressedBlock.getUncompressLength();
+      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+      codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+      uncompressedData = decompressedBuffer.array();
+      unCompressionLength += compressedBlock.getUncompressLength();
+      long decompressDuration = System.currentTimeMillis() - startDecompress;
+      decompressTime += decompressDuration;
+    }
+
+    if (uncompressedData != null) {
+      // start to merge
+      final long startSerialization = System.currentTimeMillis();
+      int compressedDataLength = compressedData != null ? compressedData.capacity() : 0;
+      if (issueMapOutputMerge(compressedDataLength, blockStartFetch)) {
+        long serializationDuration = System.currentTimeMillis() - startSerialization;
+        serializeTime += serializationDuration;
+        // if reserve succeeds, reset status for next fetch
+        if (hasPendingData) {
+          waitTime += System.currentTimeMillis() - startWait;
+        }
+        hasPendingData = false;
+        uncompressedData = null;
+      } else {
+        LOG.info("uncompressedData is null");
+        // if reserve fail, return and wait
+        waitCount++;
+        startWait = System.currentTimeMillis();
+        return;
+      }
+
+      // update some status
+      copyBlockCount++;
+      copyTime = readTime + decompressTime + serializeTime + waitTime;
+    } else {
+      LOG.info("uncompressedData is null");
+      // finish reading data, close related reader and check data consistent
+      shuffleReadClient.close();
+      shuffleReadClient.checkProcessedBlockIds();
+      shuffleReadClient.logStatics();
+      LOG.info("reduce task partition:" + partitionId + " read block cnt: " + copyBlockCount + " cost "
+          + readTime + " ms to fetch and " + decompressTime + " ms to decompress with unCompressionLength["
+          + unCompressionLength + "] and " + serializeTime + " ms to serialize and "
+          + waitTime + " ms to wait resource, total copy time: " + copyTime);
+      LOG.info("Stop fetcher");
+      stopFetch();
+    }
+  }
+
+  private boolean issueMapOutputMerge(int compressedLength, long blockStartFetch) throws IOException {
+    // Allocate a MapOutput (either in-memory or on-disk) to hold the uncompressed block.
+    // In RSS, one map output is sent as multiple blocks, so the reducer needs to
+    // treat each "block" as a faked "map output".
+    // To avoid name conflicts, we use getNextUniqueInputAttemptIdentifier instead,
+    // which generates a unique InputAttemptIdentifier(uniqueMapId++, 0).
+    InputAttemptIdentifier uniqueInputAttemptIdentifier = getNextUniqueInputAttemptIdentifier();
+    FetchedInput fetchedInput = null;
+
+    try {
+      fetchedInput = inputManager.allocate(uncompressedData.length,
+          compressedLength, uniqueInputAttemptIdentifier);
+    } catch (IOException ioe) {
+      // kill this reduce attempt
+      throw ioe;
+    }
+
+    // Allocated space and then write data to mapOutput
+    try {
+      RssTezBypassWriter.write(fetchedInput, uncompressedData);
+      // let the merger know this block is ready for merging
+      fetcherCallback.fetchSucceeded(null, uniqueInputAttemptIdentifier, fetchedInput,
+          compressedLength, unCompressionLength, System.currentTimeMillis() - blockStartFetch);
+    } catch (Throwable t) {
+      LOG.error("Failed to write fetchedInput.", t);
+      throw new RssException("Partition: " + partitionId + " cannot write block to "
+          + fetchedInput.getClass().getSimpleName() + " due to: " + t.getClass().getName());
+    }
+    return true;
+  }
+
+  private InputAttemptIdentifier getNextUniqueInputAttemptIdentifier() {
+    return new InputAttemptIdentifier(uniqueMapId++, 0);
+  }
+
+  private void stopFetch() {
+    stopped = true;
+  }
+
+  @VisibleForTesting
+  public int getRetryCount() {
+    return waitCount;
+  }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java
new file mode 100644
index 0000000..4d00a5a
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.serializer.Deserializer;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.tez.common.TezUtilsInternal;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.shuffle.FetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.FetcherCallback;
+import org.apache.tez.runtime.library.common.shuffle.MemoryFetchedInput;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryReader;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryWriter;
+import org.apache.tez.runtime.library.common.sort.impl.IFile;
+import org.apache.tez.runtime.library.utils.BufferUtils;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class RssTezFetcherTest {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezFetcherTest.class);
+
+  static Configuration conf = new Configuration();
+
+  static FileSystem fs;
+  static List<byte[]> data;
+  static List<KVPair> textData;
+  static Codec codec = new Lz4Codec();
+
+  @Test
+  public void writeAndReadDataTestWithoutRss() throws Throwable {
+    fs = FileSystem.getLocal(conf);
+    initRssData();
+    ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
+
+    SimpleFetchedInputAllocator inputManager = new SimpleFetchedInputAllocator(
+        TezUtilsInternal.cleanVertexName("Map 1"),
+        "1",
+        2, conf,
+        200 * 1024 * 1024L,
+        300 * 1024 * 1024L);
+
+    List<byte[]> result = new ArrayList<byte[]>();
+    FetcherCallback fetcherCallback = new FetcherCallback() {
+      @Override
+      public void fetchSucceeded(String host, InputAttemptIdentifier srcAttemptIdentifier, FetchedInput fetchedInput,
+            long fetchedBytes, long decompressedLength, long copyDuration) throws IOException {
+        LOG.info("Fetch success");
+        fetchedInput.commit();
+        result.add(((MemoryFetchedInput) fetchedInput).getBytes());
+      }
+
+      @Override
+      public void fetchFailed(String host, InputAttemptIdentifier srcAttemptIdentifier, boolean connectFailed) {
+        fail();
+      }
+    };
+
+    RssTezFetcher rssFetcher = new RssTezFetcher(fetcherCallback,
+        inputManager,
+        shuffleReadClient,
+        null,
+        2,
+        new RssConf());
+    rssFetcher.fetchAllRssBlocks();
+
+    for (int i = 0; i < data.size(); i++) {
+      Text readKey = new Text();
+      IntWritable readValue = new IntWritable();
+      Deserializer<Text> keyDeserializer;
+      Deserializer<IntWritable> valDeserializer;
+      SerializationFactory serializationFactory = new SerializationFactory(conf);
+      keyDeserializer = serializationFactory.getDeserializer(Text.class);
+      valDeserializer = serializationFactory.getDeserializer(IntWritable.class);
+      DataInputBuffer keyIn = new DataInputBuffer();
+      DataInputBuffer valIn = new DataInputBuffer();
+      keyDeserializer.open(keyIn);
+      valDeserializer.open(valIn);
+
+      InMemoryReader reader = new InMemoryReader(null,
+          new InputAttemptIdentifier(0, 0), result.get(i), 0, result.get(i).length);
+      int numRecordsRead = 0;
+      while (reader.nextRawKey(keyIn)) {
+        reader.nextRawValue(valIn);
+        readKey = keyDeserializer.deserialize(readKey);
+        readValue = valDeserializer.deserialize(readValue);
+
+        KVPair expected = textData.get(numRecordsRead);
+        assertEquals(expected.getKey(), readKey);
+        assertEquals(expected.getvalue(), readValue);
+
+        numRecordsRead++;
+      }
+      assertEquals(textData.size(), numRecordsRead);
+      LOG.info("Found: " + numRecordsRead + " records");
+    }
+  }
+
+  private static void initRssData() throws Exception {
+    InMemoryWriter writer = null;
+    BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1024 * 1024);
+    List<KVPair> pairData = generateTestData(true, 10);
+    // No RLE, no repeated keys, no compression
+    writer = new InMemoryWriter(bout);
+    writeTestFileUsingDataBuffer(writer, false, pairData);
+    data = new ArrayList<>();
+    data.add(bout.getBuffer());
+    textData = new ArrayList<>();
+    textData.addAll(pairData);
+  }
+
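+  /** Serves the given byte arrays as compressed blocks in order, returning null when exhausted. */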
+  static class MockedShuffleReadClient implements ShuffleReadClient {
+    List<CompressedShuffleBlock> blocks;
+    int index = 0;
+
+    MockedShuffleReadClient(List<byte[]> data) {
+      this.blocks = new LinkedList<>();
+      data.forEach(bytes -> {
+        byte[] compressed = codec.compress(bytes);
+        blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
+      });
+    }
+
+    @Override
+    public CompressedShuffleBlock readShuffleBlockData() {
+      if (index < blocks.size()) {
+        return blocks.get(index++);
+      } else {
+        return null;
+      }
+    }
+
+    @Override
+    public void checkProcessedBlockIds() {
+    }
+
+    @Override
+    public void close() {
+    }
+
+    @Override
+    public void logStatics() {
+    }
+  }
+
+  /**
+   * Generate key value pairs.
+   *
+   * @param sorted whether data should be sorted by key
+   * @param repeatCount number of keys to be repeated
+   * @return the generated key value pairs
+   */
+  public static List<KVPair> generateTestData(boolean sorted, int repeatCount) {
+    List<KVPair> data = new LinkedList<KVPair>();
+    Random rnd = new Random();
+    KVPair kvp = null;
+    for (int i = 0; i < 5; i++) {
+      String keyStr = (sorted) ? ("key" + i) : (rnd.nextLong() + "key" + i);
+      Text key = new Text(keyStr);
+      IntWritable value = new IntWritable(i + repeatCount);
+      kvp = new KVPair(key, value);
+      data.add(kvp);
+      if ((repeatCount > 0) && (i % 2 == 0)) { // Repeat this key a random number of times
+        int count = rnd.nextInt(5);
+        for (int j = 0; j < count; j++) {
+          repeatCount++;
+          value.set(i + rnd.nextInt());
+          kvp = new KVPair(key, value);
+          data.add(kvp);
+        }
+      }
+    }
+    // If we need to generate repeated keys, also add some repeated keys to the end of the file.
+    if (repeatCount > 0 && kvp != null) {
+      data.add(kvp);
+      data.add(kvp);
+    }
+    return data;
+  }
+
+  private static IFile.Writer writeTestFileUsingDataBuffer(IFile.Writer writer, boolean repeatKeys,
+            List<KVPair> data) throws IOException {
+    DataInputBuffer previousKey = new DataInputBuffer();
+    DataInputBuffer key = new DataInputBuffer();
+    DataInputBuffer value = new DataInputBuffer();
+    for (KVPair kvp : data) {
+      populateData(kvp, key, value);
+
+      if (repeatKeys && (previousKey != null && BufferUtils.compare(key, previousKey) == 0)) {
+        writer.append(org.apache.tez.runtime.library.common.sort.impl.IFile.REPEAT_KEY, value);
+      } else {
+        writer.append(key, value);
+      }
+      previousKey.reset(key.getData(), 0, key.getLength());
+    }
+    writer.close();
+    LOG.info("Uncompressed: " + writer.getRawLength());
+    LOG.info("CompressedSize: " + writer.getCompressedLength());
+    return writer;
+  }
+
+  private static void populateData(KVPair kvp, DataInputBuffer key, DataInputBuffer value)
+      throws IOException {
+    DataOutputBuffer k = new DataOutputBuffer();
+    DataOutputBuffer v = new DataOutputBuffer();
+    kvp.getKey().write(k);
+    kvp.getvalue().write(v);
+    key.reset(k.getData(), 0, k.getLength());
+    value.reset(v.getData(), 0, v.getLength());
+  }
+
+  public static class KVPair {
+    private Text key;
+    private IntWritable value;
+
+    public KVPair(Text key, IntWritable value) {
+      this.key = key;
+      this.value = value;
+    }
+
+    public Text getKey() {
+      return this.key;
+    }
+
+    public IntWritable getvalue() {
+      return this.value;
+    }
+  }
+}