TEZ-3709. TezMerger is slow for high number of segments (jeagles)

(cherry picked from commit 4d100b2bfb880927932ff095f2ba02780d5df01a)
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/MergeManager.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/MergeManager.java
index 26bdca7..7321397 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/MergeManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/MergeManager.java
@@ -1005,8 +1005,9 @@
     for (MapOutput mo : inMemoryMapOutputs) {
       fullSize += mo.getMemory().length;
     }
+    int inMemoryMapOutputsOffset = 0;
     while((fullSize > leaveBytes) && !Thread.currentThread().isInterrupted()) {
-      MapOutput mo = inMemoryMapOutputs.remove(0);
+      MapOutput mo = inMemoryMapOutputs.get(inMemoryMapOutputsOffset++);
       byte[] data = mo.getMemory();
       long size = data.length;
       totalSize += size;
@@ -1018,6 +1019,8 @@
                                             (mo.isPrimaryMapOutput() ? 
                                             mergedMapOutputsCounter : null)));
     }
+    // Bulk remove removed in-memory map outputs efficiently
+    inMemoryMapOutputs.subList(0, inMemoryMapOutputsOffset).clear();
     return totalSize;
   }
 
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/TezMerger.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/TezMerger.java
index 17e0fe2..23f2946 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/TezMerger.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/TezMerger.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
@@ -721,7 +722,7 @@
         mergeProgress = mergePhase;
       }
 
-      long totalBytes = computeBytesInMerges(factor, inMem);
+      long totalBytes = computeBytesInMerges(segments, factor, inMem, considerFinalMergeForProgress);
       if (totalBytes != 0) {
         progPerByte = 1.0f / (float)totalBytes;
       }
@@ -891,7 +892,7 @@
      * number of segments - 1 to be divisible by the factor - 1 (each pass
      * takes X segments and produces 1) to minimize the number of merges.
      */
-    private int getPassFactor(int factor, int passNo, int numSegments) {
+    private static int getPassFactor(int factor, int passNo, int numSegments) {
       // passNo > 1 in the OR list - is that correct ?
       if (passNo > 1 || numSegments <= factor || factor == 1) 
         return factor;
@@ -910,14 +911,12 @@
         segments.clear();
         return subList;
       }
-      
-      List<Segment> subList = 
-        new ArrayList<Segment>(segments.subList(0, numDescriptors));
-      // TODO Replace this with a batch operation
-      for (int i=0; i < numDescriptors; ++i) {
-        segments.remove(0);
-      }
-      return subList;
+
+      // Efficiently bulk remove segments
+      List<Segment> subList = segments.subList(0, numDescriptors);
+      List<Segment> subListCopy = new ArrayList<>(subList);
+      subList.clear();
+      return subListCopy;
     }
     
     /**
@@ -925,12 +924,14 @@
      * calculating mergeProgress. This simulates the above merge() method and
      * tries to obtain the number of bytes that are going to be merged in all
      * merges(assuming that there is no combiner called while merging).
+     * @param segments segments to compute merge bytes
      * @param factor mapreduce.task.io.sort.factor
      * @param inMem  number of segments in memory to be merged
+     * @param considerFinalMergeForProgress whether to consider for final merge
      */
-    long computeBytesInMerges(int factor, int inMem) {
+    static long computeBytesInMerges(List<Segment> segments, int factor, int inMem, boolean considerFinalMergeForProgress) {
       int numSegments = segments.size();
-      List<Long> segmentSizes = new ArrayList<Long>(numSegments);
+      long[] segmentSizes = new long[numSegments];
       long totalBytes = 0;
       int n = numSegments - inMem;
       // factor for 1st pass
@@ -940,33 +941,67 @@
       for (int i = 0; i < numSegments; i++) {
         // Not handling empty segments here assuming that it would not affect
         // much in calculation of mergeProgress.
-        segmentSizes.add(segments.get(i).getLength());
+        segmentSizes[i] = segments.get(i).getLength();
       }
       
       // If includeFinalMerge is true, allow the following while loop iterate
       // for 1 more iteration. This is to include final merge as part of the
       // computation of expected input bytes of merges
       boolean considerFinalMerge = considerFinalMergeForProgress;
-      
+
+      int offset = 0;
       while (n > f || considerFinalMerge) {
-        if (n <=f ) {
+        if (n <= f) {
           considerFinalMerge = false;
         }
         long mergedSize = 0;
-        f = Math.min(f, segmentSizes.size());
+        f = Math.min(f, n);
         for (int j = 0; j < f; j++) {
-          mergedSize += segmentSizes.remove(0);
+          mergedSize += segmentSizes[offset + j];
         }
         totalBytes += mergedSize;
         
         // insert new size into the sorted list
-        int pos = Collections.binarySearch(segmentSizes, mergedSize);
+        int pos = Arrays.binarySearch(segmentSizes, offset, offset + n, mergedSize);
         if (pos < 0) {
           pos = -pos-1;
         }
-        segmentSizes.add(pos, mergedSize);
-
-        n -= (f-1);
+        if (pos < offset + f) {
+          // Insert at the beginning
+          offset += f - 1;
+          segmentSizes[offset] = mergedSize;
+        } else if (pos < offset + n) {
+          // Insert in the middle
+          if (offset + n < segmentSizes.length) {
+            // Shift right after insertion point into unused capacity
+            System.arraycopy(segmentSizes, pos, segmentSizes, pos + 1, offset + n - pos);
+            // Insert into insertion point
+            segmentSizes[pos] = mergedSize;
+            offset += f;
+          } else {
+            // Full left shift before insertion point
+            System.arraycopy(segmentSizes, offset + f, segmentSizes, 0, pos - (offset + f));
+            // Insert in the middle
+            segmentSizes[pos - (offset + f)] = mergedSize;
+            // Full left shift after insertion point
+            System.arraycopy(segmentSizes, pos, segmentSizes, pos - (offset + f) + 1, offset + n - pos);
+            offset = 0;
+          }
+        } else {
+          // Insert at the end
+          if (pos < segmentSizes.length) {
+            // Append into unused capacity
+            segmentSizes[pos] = mergedSize;
+            offset += f;
+          } else {
+            // Full left shift
+            // Append at the end
+            System.arraycopy(segmentSizes, offset + f, segmentSizes, 0, n - f);
+            segmentSizes[n - f] = mergedSize;
+            offset = 0;
+          }
+        }
+        n -=  f - 1;
         f = factor;
       }