Integrate merge-time index reordering with the intra-merge executor. (#13289)

Index reordering can benefit greatly from parallelism, so it should try to use
the intra-merge executor when possible.
diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
index 7d1a57b..def9ef0 100644
--- a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
@@ -37,6 +37,7 @@
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.Executor;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -3434,9 +3435,11 @@
                 .map(FieldInfos::getParentField)
                 .anyMatch(Objects::isNull);
 
+    final Executor intraMergeExecutor = mergeScheduler.getIntraMergeExecutor(merge);
+
     if (hasIndexSort == false && hasBlocksButNoParentField == false && readers.isEmpty() == false) {
       CodecReader mergedReader = SlowCompositeCodecReaderWrapper.wrap(readers);
-      DocMap docMap = merge.reorder(mergedReader, directory);
+      DocMap docMap = merge.reorder(mergedReader, directory, intraMergeExecutor);
       if (docMap != null) {
         readers = Collections.singletonList(SortingCodecReader.wrap(mergedReader, docMap, null));
       }
@@ -3450,7 +3453,7 @@
             trackingDir,
             globalFieldNumberMap,
             context,
-            mergeScheduler.getIntraMergeExecutor(merge));
+            intraMergeExecutor);
 
     if (!merger.shouldMerge()) {
       return;
@@ -3928,9 +3931,9 @@
                       }
 
                       @Override
-                      public Sorter.DocMap reorder(CodecReader reader, Directory dir)
-                          throws IOException {
-                        return toWrap.reorder(reader, dir); // must delegate
+                      public Sorter.DocMap reorder(
+                          CodecReader reader, Directory dir, Executor executor) throws IOException {
+                        return toWrap.reorder(reader, dir, executor); // must delegate
                       }
 
                       @Override
@@ -5205,6 +5208,8 @@
         mergeReaders.add(wrappedReader);
       }
 
+      final Executor intraMergeExecutor = mergeScheduler.getIntraMergeExecutor(merge);
+
       MergeState.DocMap[] reorderDocMaps = null;
       // Don't reorder if an explicit sort is configured.
       final boolean hasIndexSort = config.getIndexSort() != null;
@@ -5219,7 +5224,7 @@
       if (hasIndexSort == false && hasBlocksButNoParentField == false) {
         // Create a merged view of the input segments. This effectively does the merge.
         CodecReader mergedView = SlowCompositeCodecReaderWrapper.wrap(mergeReaders);
-        Sorter.DocMap docMap = merge.reorder(mergedView, directory);
+        Sorter.DocMap docMap = merge.reorder(mergedView, directory, intraMergeExecutor);
         if (docMap != null) {
           reorderDocMaps = new MergeState.DocMap[mergeReaders.size()];
           int docBase = 0;
@@ -5249,7 +5254,7 @@
               dirWrapper,
               globalFieldNumberMap,
               context,
-              mergeScheduler.getIntraMergeExecutor(merge));
+              intraMergeExecutor);
       merge.info.setSoftDelCount(Math.toIntExact(softDeleteCount.get()));
       merge.checkAborted();
 
diff --git a/lucene/core/src/java/org/apache/lucene/index/MergePolicy.java b/lucene/core/src/java/org/apache/lucene/index/MergePolicy.java
index 2eaab9c..d66f564 100644
--- a/lucene/core/src/java/org/apache/lucene/index/MergePolicy.java
+++ b/lucene/core/src/java/org/apache/lucene/index/MergePolicy.java
@@ -27,6 +27,7 @@
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
@@ -292,7 +293,7 @@
      * Wrap a reader prior to merging in order to add/remove fields or documents.
      *
      * <p><b>NOTE:</b> It is illegal to reorder doc IDs here, use {@link
-     * #reorder(CodecReader,Directory)} instead.
+     * #reorder(CodecReader,Directory,Executor)} instead.
      */
     public CodecReader wrapForMerge(CodecReader reader) throws IOException {
       return reader;
@@ -308,9 +309,12 @@
      *
      * @param reader The reader to reorder.
      * @param dir The {@link Directory} of the index, which may be used to create temporary files.
+     * @param executor An executor that can be used to parallelize the reordering logic. May be
+     *     {@code null} if no concurrency is supported.
      * @lucene.experimental
      */
-    public Sorter.DocMap reorder(CodecReader reader, Directory dir) throws IOException {
+    public Sorter.DocMap reorder(CodecReader reader, Directory dir, Executor executor)
+        throws IOException {
       return null;
     }
 
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
index 43b6c7b..f51321f 100644
--- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
@@ -22,8 +22,8 @@
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Set;
-import java.util.concurrent.ForkJoinPool;
-import java.util.concurrent.RecursiveAction;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executor;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DocValues;
@@ -37,6 +37,7 @@
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
 import org.apache.lucene.store.ByteBuffersDataOutput;
 import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.DataOutput;
@@ -123,7 +124,6 @@
   private float maxDocFreq;
   private int minPartitionSize;
   private int maxIters;
-  private ForkJoinPool forkJoinPool;
   private double ramBudgetMB;
   private Set<String> fields;
 
@@ -133,7 +133,6 @@
     setMaxDocFreq(1f);
     setMinPartitionSize(DEFAULT_MIN_PARTITION_SIZE);
     setMaxIters(DEFAULT_MAX_ITERS);
-    setForkJoinPool(null);
     // 10% of the available heap size by default
     setRAMBudgetMB(Runtime.getRuntime().totalMemory() / 1024d / 1024d / 10d);
     setFields(null);
@@ -182,20 +181,6 @@
   }
 
   /**
-   * Set the {@link ForkJoinPool} to run graph partitioning concurrently.
-   *
-   * <p>NOTE: A value of {@code null} can be used to run in the current thread, which is the
-   * default.
-   */
-  public void setForkJoinPool(ForkJoinPool forkJoinPool) {
-    this.forkJoinPool = forkJoinPool;
-  }
-
-  private int getParallelism() {
-    return forkJoinPool == null ? 1 : forkJoinPool.getParallelism();
-  }
-
-  /**
    * Set the amount of RAM that graph partitioning is allowed to use. More RAM allows running
    * faster. If not enough RAM is provided, a {@link NotEnoughRAMException} will be thrown. This is
    * 10% of the total heap size by default.
@@ -225,21 +210,18 @@
     }
   }
 
-  private abstract class BaseRecursiveAction extends RecursiveAction {
+  private abstract class BaseRecursiveAction implements Callable<Void> {
 
+    protected final TaskExecutor executor;
     protected final int depth;
 
-    BaseRecursiveAction(int depth) {
+    BaseRecursiveAction(TaskExecutor executor, int depth) {
+      this.executor = executor;
       this.depth = depth;
     }
 
     protected final boolean shouldFork(int problemSize, int totalProblemSize) {
-      if (forkJoinPool == null) {
-        return false;
-      }
-      if (getSurplusQueuedTaskCount() > 3) {
-        // Fork tasks if this worker doesn't have more queued work than other workers
-        // See javadocs of #getSurplusQueuedTaskCount for more details
+      if (executor == null) {
         return false;
       }
       if (problemSize == totalProblemSize) {
@@ -249,6 +231,18 @@
       }
       return problemSize > FORK_THRESHOLD;
     }
+
+    @Override
+    public abstract Void call();
+
+    protected final void invokeAll(BaseRecursiveAction... actions) {
+      assert executor != null : "Only call invokeAll if shouldFork returned true";
+      try {
+        executor.invokeAll(Arrays.asList(actions));
+      } catch (IOException e) {
+        throw new UncheckedIOException(e);
+      }
+    }
   }
 
   private class IndexReorderingTask extends BaseRecursiveAction {
@@ -263,8 +257,9 @@
         float[] biases,
         CloseableThreadLocal<PerThreadState> threadLocal,
         BitSet parents,
+        TaskExecutor executor,
         int depth) {
-      super(depth);
+      super(executor, depth);
       this.docIDs = docIDs;
       this.biases = biases;
       this.threadLocal = threadLocal;
@@ -292,7 +287,7 @@
     }
 
     @Override
-    protected void compute() {
+    public Void call() {
       if (depth > 0) {
         Arrays.sort(docIDs.ints, docIDs.offset, docIDs.offset + docIDs.length);
       } else {
@@ -302,7 +297,7 @@
 
       int halfLength = docIDs.length / 2;
       if (halfLength < minPartitionSize) {
-        return;
+        return null;
       }
 
       IntsRef left = new IntsRef(docIDs.ints, docIDs.offset, halfLength);
@@ -349,7 +344,7 @@
           if (split == docIDs.offset) {
             // No good split on the left side either: this slice has a single parent document, no
             // reordering is possible. Stop recursing.
-            return;
+            return null;
           }
         }
 
@@ -362,16 +357,17 @@
       // It is fine for all tasks to share the same docs / biases array since they all work on
       // different slices of the array at a given point in time.
       IndexReorderingTask leftTask =
-          new IndexReorderingTask(left, biases, threadLocal, parents, depth + 1);
+          new IndexReorderingTask(left, biases, threadLocal, parents, executor, depth + 1);
       IndexReorderingTask rightTask =
-          new IndexReorderingTask(right, biases, threadLocal, parents, depth + 1);
+          new IndexReorderingTask(right, biases, threadLocal, parents, executor, depth + 1);
 
       if (shouldFork(docIDs.length, docIDs.ints.length)) {
         invokeAll(leftTask, rightTask);
       } else {
-        leftTask.compute();
-        rightTask.compute();
+        leftTask.call();
+        rightTask.call();
       }
+      return null;
     }
 
     // used for asserts
@@ -422,8 +418,9 @@
               leftDocFreqs,
               rightDocFreqs,
               threadLocal,
+              executor,
               depth)
-          .compute();
+          .call();
 
       if (parents != null) {
         for (int i = docIDs.offset, end = docIDs.offset + docIDs.length; i < end; ) {
@@ -592,8 +589,9 @@
         int[] fromDocFreqs,
         int[] toDocFreqs,
         CloseableThreadLocal<PerThreadState> threadLocal,
+        TaskExecutor executor,
         int depth) {
-      super(depth);
+      super(executor, depth);
       this.docs = docs;
       this.biases = biases;
       this.from = from;
@@ -604,15 +602,15 @@
     }
 
     @Override
-    protected void compute() {
+    public Void call() {
       final int problemSize = to - from;
       if (problemSize > 1 && shouldFork(problemSize, docs.length)) {
         final int mid = (from + to) >>> 1;
         invokeAll(
             new ComputeBiasTask(
-                docs, biases, from, mid, fromDocFreqs, toDocFreqs, threadLocal, depth),
+                docs, biases, from, mid, fromDocFreqs, toDocFreqs, threadLocal, executor, depth),
             new ComputeBiasTask(
-                docs, biases, mid, to, fromDocFreqs, toDocFreqs, threadLocal, depth));
+                docs, biases, mid, to, fromDocFreqs, toDocFreqs, threadLocal, executor, depth));
       } else {
         ForwardIndex forwardIndex = threadLocal.get().forwardIndex;
         try {
@@ -623,6 +621,7 @@
           throw new UncheckedIOException(e);
         }
       }
+      return null;
     }
 
     /**
@@ -707,12 +706,16 @@
   }
 
   private int writePostings(
-      CodecReader reader, Set<String> fields, Directory tempDir, DataOutput postingsOut)
+      CodecReader reader,
+      Set<String> fields,
+      Directory tempDir,
+      DataOutput postingsOut,
+      int parallelism)
       throws IOException {
     final int maxNumTerms =
         (int)
             ((ramBudgetMB * 1024 * 1024 - docRAMRequirements(reader.maxDoc()))
-                / getParallelism()
+                / parallelism
                 / termRAMRequirementsPerThreadPerTerm());
     final int maxDocFreq = (int) ((double) this.maxDocFreq * reader.maxDoc());
 
@@ -825,9 +828,10 @@
   /**
    * Expert: Compute the {@link DocMap} that holds the new doc ID numbering. This is exposed to
    * enable integration into {@link BPReorderingMergePolicy}, {@link #reorder(CodecReader,
-   * Directory)} should be preferred in general.
+   * Directory, Executor)} should be preferred in general.
    */
-  public Sorter.DocMap computeDocMap(CodecReader reader, Directory tempDir) throws IOException {
+  public Sorter.DocMap computeDocMap(CodecReader reader, Directory tempDir, Executor executor)
+      throws IOException {
     if (docRAMRequirements(reader.maxDoc()) >= ramBudgetMB * 1024 * 1024) {
       throw new NotEnoughRAMException(
           "At least "
@@ -847,7 +851,8 @@
       }
     }
 
-    int[] newToOld = computePermutation(reader, fields, tempDir);
+    TaskExecutor taskExecutor = executor == null ? null : new TaskExecutor(executor);
+    int[] newToOld = computePermutation(reader, fields, tempDir, taskExecutor);
     int[] oldToNew = new int[newToOld.length];
     for (int i = 0; i < newToOld.length; ++i) {
       oldToNew[newToOld[i]] = i;
@@ -877,27 +882,42 @@
    * evaluation efficiency. Note that the returned {@link CodecReader} is slow and should typically
    * be used in a call to {@link IndexWriter#addIndexes(CodecReader...)}.
    *
+   * <p>The provided {@link Executor} can be used to perform reordering concurrently. A value of
+   * {@code null} indicates that reordering should be performed in the current thread.
+   *
+   * <p><b>NOTE</b>: The provided {@link Executor} must not reject tasks.
+   *
    * @throws NotEnoughRAMException if not enough RAM is provided
    */
-  public CodecReader reorder(CodecReader reader, Directory tempDir) throws IOException {
-    Sorter.DocMap docMap = computeDocMap(reader, tempDir);
+  public CodecReader reorder(CodecReader reader, Directory tempDir, Executor executor)
+      throws IOException {
+    Sorter.DocMap docMap = computeDocMap(reader, tempDir, executor);
     return SortingCodecReader.wrap(reader, docMap, null);
   }
 
   /**
    * Compute a permutation of the doc ID space that reduces log gaps between consecutive postings.
    */
-  private int[] computePermutation(CodecReader reader, Set<String> fields, Directory dir)
+  private int[] computePermutation(
+      CodecReader reader, Set<String> fields, Directory dir, TaskExecutor executor)
       throws IOException {
     TrackingDirectoryWrapper trackingDir = new TrackingDirectoryWrapper(dir);
 
+    final int parallelism;
+    if (executor == null) {
+      parallelism = 1;
+    } else {
+      // Assume as many threads as processors
+      parallelism = Runtime.getRuntime().availableProcessors();
+    }
+
     final int maxDoc = reader.maxDoc();
     ForwardIndex forwardIndex = null;
     IndexOutput postingsOutput = null;
     boolean success = false;
     try {
       postingsOutput = trackingDir.createTempOutput("postings", "", IOContext.DEFAULT);
-      int numTerms = writePostings(reader, fields, trackingDir, postingsOutput);
+      int numTerms = writePostings(reader, fields, trackingDir, postingsOutput, parallelism);
       CodecUtil.writeFooter(postingsOutput);
       postingsOutput.close();
       final ForwardIndex finalForwardIndex =
@@ -924,14 +944,7 @@
             }
           }) {
         IntsRef docs = new IntsRef(sortedDocs, 0, sortedDocs.length);
-        IndexReorderingTask task =
-            new IndexReorderingTask(docs, new float[maxDoc], threadLocal, parents, 0);
-        if (forkJoinPool != null) {
-          forkJoinPool.execute(task);
-          task.join();
-        } else {
-          task.compute();
-        }
+        new IndexReorderingTask(docs, new float[maxDoc], threadLocal, parents, executor, 0).call();
       }
 
       success = true;
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
index 077b389..5cd3631 100644
--- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
@@ -19,6 +19,7 @@
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
+import java.util.concurrent.Executor;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.FilterMergePolicy;
 import org.apache.lucene.index.MergePolicy;
@@ -129,11 +130,12 @@
             }
 
             @Override
-            public Sorter.DocMap reorder(CodecReader reader, Directory dir) throws IOException {
+            public Sorter.DocMap reorder(CodecReader reader, Directory dir, Executor executor)
+                throws IOException {
               Sorter.DocMap docMap = null;
               if (reader.numDocs() >= minNumDocs) {
                 try {
-                  docMap = reorderer.computeDocMap(reader, dir);
+                  docMap = reorderer.computeDocMap(reader, dir, executor);
                 } catch (
                     @SuppressWarnings("unused")
                     NotEnoughRAMException e) {
diff --git a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java
index f4322ff..b7da308 100644
--- a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java
+++ b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java
@@ -116,11 +116,10 @@
     CodecReader codecReader = SlowCodecReaderWrapper.wrap(leafRealer);
 
     BPIndexReorderer reorderer = new BPIndexReorderer();
-    reorderer.setForkJoinPool(pool);
     reorderer.setMinDocFreq(2);
     reorderer.setMinPartitionSize(1);
     reorderer.setMaxIters(10);
-    CodecReader reordered = reorderer.reorder(codecReader, dir);
+    CodecReader reordered = reorderer.reorder(codecReader, dir, pool);
     String[] ids = new String[codecReader.maxDoc()];
     StoredFields storedFields = reordered.storedFields();
     for (int i = 0; i < codecReader.maxDoc(); ++i) {
@@ -180,11 +179,10 @@
     CodecReader codecReader = SlowCodecReaderWrapper.wrap(leafRealer);
 
     BPIndexReorderer reorderer = new BPIndexReorderer();
-    reorderer.setForkJoinPool(pool);
     reorderer.setMinDocFreq(2);
     reorderer.setMinPartitionSize(1);
     reorderer.setMaxIters(10);
-    CodecReader reordered = reorderer.reorder(codecReader, dir);
+    CodecReader reordered = reorderer.reorder(codecReader, dir, pool);
     StoredFields storedFields = reordered.storedFields();
 
     assertEquals("2", storedFields.document(0).get("id"));
@@ -307,7 +305,7 @@
     reorderer.setMinDocFreq(2);
     reorderer.setMinPartitionSize(1);
     reorderer.setMaxIters(10);
-    CodecReader reordered = reorderer.reorder(codecReader, dir);
+    CodecReader reordered = reorderer.reorder(codecReader, dir, null);
     String[] ids = new String[codecReader.maxDoc()];
     StoredFields storedFields = reordered.storedFields();
     for (int i = 0; i < codecReader.maxDoc(); ++i) {
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MockRandomMergePolicy.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MockRandomMergePolicy.java
index 1b509c7..74f3b87 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MockRandomMergePolicy.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MockRandomMergePolicy.java
@@ -23,6 +23,9 @@
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.FilterLeafReader;
@@ -36,6 +39,8 @@
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.BitSet;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.ThreadInterruptedException;
 
 /** MergePolicy that makes random decisions for testing. */
 public class MockRandomMergePolicy extends MergePolicy {
@@ -241,7 +246,8 @@
     }
 
     @Override
-    public Sorter.DocMap reorder(CodecReader reader, Directory dir) throws IOException {
+    public Sorter.DocMap reorder(CodecReader reader, Directory dir, Executor executor)
+        throws IOException {
       if (r.nextBoolean()) {
         if (LuceneTestCase.VERBOSE) {
           System.out.println("NOTE: MockRandomMergePolicy now reverses reader=" + reader);
@@ -249,6 +255,19 @@
         // Reverse the doc ID order
         return reverse(reader);
       }
+      if (executor != null && r.nextBoolean()) {
+        // submit random work to the executor
+        Runnable dummyRunnable = () -> {};
+        FutureTask<Void> task = new FutureTask<>(dummyRunnable, null);
+        executor.execute(task);
+        try {
+          task.get();
+        } catch (InterruptedException e) {
+          throw new ThreadInterruptedException(e);
+        } catch (ExecutionException e) {
+          throw IOUtils.rethrowAlways(e.getCause());
+        }
+      }
       return null;
     }
   }