Add query interruption flag check to broker groupby reduction (#9499)

* add query interruption flag check to broker groupby reduction

* add benchmark

* tiled loop

* Trigger Test
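
The core change tiles the broker-side group-by reduce loop so the thread's interruption flag is checked once per batch of rows instead of once per row. A minimal, self-contained sketch of the pattern (the batch size mirrors the new constant, but the class and method names below are illustrative, not Pinot APIs):

import java.util.function.IntConsumer;

public final class TiledInterruptionCheckSketch {
  // Illustrative batch size; the reducer uses MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK = 10_000.
  private static final int ROWS_PER_INTERRUPTION_CHECK = 10_000;

  // Processes numRows rows, checking Thread.interrupted() once per batch (loop tiling).
  // Returns false if the thread was interrupted before all rows were processed.
  static boolean processRows(int numRows, IntConsumer processRow) {
    for (int batchStart = 0; batchStart < numRows; batchStart += ROWS_PER_INTERRUPTION_CHECK) {
      if (Thread.interrupted()) {
        return false; // query cancelled or timed out; stop reducing
      }
      int batchEnd = Math.min(batchStart + ROWS_PER_INTERRUPTION_CHECK, numRows);
      for (int rowId = batchStart; rowId < batchEnd; rowId++) {
        processRow.accept(rowId);
      }
    }
    return true;
  }

  public static void main(String[] args) {
    boolean completed = processRows(1_000_000, rowId -> { /* e.g. upsert the row into an indexed table */ });
    System.out.println("completed = " + completed);
  }
}
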
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
index 7a71a59..5b27b2d 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
@@ -694,9 +694,9 @@
         LOGGER.debug("Remove track of running query: {}", requestId);
       }
     } else {
-      brokerResponse =
-          processBrokerRequest(requestId, brokerRequest, serverBrokerRequest, offlineBrokerRequest, offlineRoutingTable,
-              realtimeBrokerRequest, realtimeRoutingTable, remainingTimeMs, serverStats, requestContext);
+      brokerResponse = processBrokerRequest(requestId, brokerRequest, serverBrokerRequest, offlineBrokerRequest,
+          offlineRoutingTable, realtimeBrokerRequest, realtimeRoutingTable, remainingTimeMs, serverStats,
+          requestContext);
     }
 
     brokerResponse.setExceptions(exceptions);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index 0bf49c7..cfa1621 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -64,6 +64,7 @@
 @SuppressWarnings({"rawtypes", "unchecked"})
 public class GroupByDataTableReducer implements DataTableReducer {
   private static final int MIN_DATA_TABLES_FOR_CONCURRENT_REDUCE = 2; // TBD, find a better value.
+  private static final int MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK = 10_000;
 
   private final QueryContext _queryContext;
   private final AggregationFunction[] _aggregationFunctions;
@@ -288,51 +289,57 @@
               }
 
               int numRows = dataTable.getNumberOfRows();
-              for (int rowId = 0; rowId < numRows; rowId++) {
-                Object[] values = new Object[_numColumns];
-                for (int colId = 0; colId < _numColumns; colId++) {
-                  switch (storedColumnDataTypes[colId]) {
-                    case INT:
-                      values[colId] = dataTable.getInt(rowId, colId);
-                      break;
-                    case LONG:
-                      values[colId] = dataTable.getLong(rowId, colId);
-                      break;
-                    case FLOAT:
-                      values[colId] = dataTable.getFloat(rowId, colId);
-                      break;
-                    case DOUBLE:
-                      values[colId] = dataTable.getDouble(rowId, colId);
-                      break;
-                    case BIG_DECIMAL:
-                      values[colId] = dataTable.getBigDecimal(rowId, colId);
-                      break;
-                    case STRING:
-                      values[colId] = dataTable.getString(rowId, colId);
-                      break;
-                    case BYTES:
-                      values[colId] = dataTable.getBytes(rowId, colId);
-                      break;
-                    case OBJECT:
-                      // TODO: Move ser/de into AggregationFunction interface
-                      DataTable.CustomObject customObject = dataTable.getCustomObject(rowId, colId);
-                      if (customObject != null) {
-                        values[colId] = ObjectSerDeUtils.deserialize(customObject);
-                      }
-                      break;
-                    // Add other aggregation intermediate result / group-by column type supports here
-                    default:
-                      throw new IllegalStateException();
-                  }
+              for (int rowIdBatch = 0; rowIdBatch < numRows; rowIdBatch += MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK) {
+                if (Thread.interrupted()) {
+                  return; // query cancelled or timed out; the finally block still counts down the latch
                 }
-                if (nullHandlingEnabled) {
+                int upper = Math.min(rowIdBatch + MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK, numRows);
+                for (int rowId = rowIdBatch; rowId < upper; rowId++) {
+                  Object[] values = new Object[_numColumns];
                   for (int colId = 0; colId < _numColumns; colId++) {
-                    if (nullBitmaps[colId] != null && nullBitmaps[colId].contains(rowId)) {
-                      values[colId] = null;
+                    switch (storedColumnDataTypes[colId]) {
+                      case INT:
+                        values[colId] = dataTable.getInt(rowId, colId);
+                        break;
+                      case LONG:
+                        values[colId] = dataTable.getLong(rowId, colId);
+                        break;
+                      case FLOAT:
+                        values[colId] = dataTable.getFloat(rowId, colId);
+                        break;
+                      case DOUBLE:
+                        values[colId] = dataTable.getDouble(rowId, colId);
+                        break;
+                      case BIG_DECIMAL:
+                        values[colId] = dataTable.getBigDecimal(rowId, colId);
+                        break;
+                      case STRING:
+                        values[colId] = dataTable.getString(rowId, colId);
+                        break;
+                      case BYTES:
+                        values[colId] = dataTable.getBytes(rowId, colId);
+                        break;
+                      case OBJECT:
+                        // TODO: Move ser/de into AggregationFunction interface
+                        DataTable.CustomObject customObject = dataTable.getCustomObject(rowId, colId);
+                        if (customObject != null) {
+                          values[colId] = ObjectSerDeUtils.deserialize(customObject);
+                        }
+                        break;
+                      // Add other aggregation intermediate result / group-by column type supports here
+                      default:
+                        throw new IllegalStateException();
                     }
                   }
+                  if (nullHandlingEnabled) {
+                    for (int colId = 0; colId < _numColumns; colId++) {
+                      if (nullBitmaps[colId] != null && nullBitmaps[colId].contains(rowId)) {
+                        values[colId] = null;
+                      }
+                    }
+                  }
+                  indexedTable.upsert(new Record(values));
                 }
-                indexedTable.upsert(new Record(values));
               }
             } finally {
               countDownLatch.countDown();
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkThreadInterruptionCheck.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkThreadInterruptionCheck.java
new file mode 100644
index 0000000..ea48694
--- /dev/null
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkThreadInterruptionCheck.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.perf;
+
+import java.util.concurrent.TimeUnit;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.openjdk.jmh.runner.options.TimeValue;
+
+
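+/**
+ * Benchmark comparing strategies for periodically checking the thread interruption flag in a tight
+ * loop: a bitmask test, a modulo test, loop tiling, and a direct Thread.interrupted() call per iteration.
+ */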
+@State(Scope.Benchmark)
+public class BenchmarkThreadInterruptionCheck {
+
+  static final int MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK_MASK = 0b111_11111_11111; // 8_191 = 2^13 - 1
+
+  public static void main(String[] args)
+      throws RunnerException {
+    Options opt =
+        new OptionsBuilder().include(BenchmarkThreadInterruptionCheck.class.getSimpleName())
+            .warmupTime(TimeValue.seconds(5))
+            .warmupIterations(3).measurementTime(TimeValue.seconds(5)).measurementIterations(5).forks(1).build();
+
+    new Runner(opt).run();
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.MILLISECONDS)
+  public void benchMaskingTime(Blackhole bh) {
+    for (int i = 0; i < 1000000; i++) {
+      bh.consume((i & MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK_MASK) == 0);
+    }
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.MILLISECONDS)
+  public void benchModuloTime(Blackhole bh) {
+    for (int i = 0; i < 1000000; i++) {
+      bh.consume((i % 16321) == 0);
+    }
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.MILLISECONDS)
+  public void benchLoopTilingTime(Blackhole bh) {
+    for (int i = 0; i < 1000000; i += 16321) {
+      bh.consume(Math.min(i + 16321, 1000000));
+    }
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.MILLISECONDS)
+  public void benchInterruptionCheckTime(Blackhole bh) {
+    for (int i = 0; i < 1000000; i++) {
+      bh.consume(Thread.interrupted());
+    }
+  }
+}
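
The reducer adopts the loop-tiling variant: the interruption check and the Math.min bound computation are paid once per MAX_ROWS_UPSERT_PER_INTERRUPTION_CHECK (10,000) upserted rows rather than once per row, whereas the bitmask, modulo, and direct Thread.interrupted() variants measured above pay a small cost on every iteration. The benchmark can be launched through its own main() method and reports the average time, in milliseconds, to cover a range of one million rows with each strategy.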