Fix bug which produces vastly inaccurate query results when forceLimitPushDown is enabled and order by clause has non grouping fields (#11097)

diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java
index f170550..6f1cee9 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java
@@ -975,7 +975,8 @@
           query.getDimensions(),
           getDimensionComparators(limitSpec),
           query.getResultRowHasTimestamp(),
-          query.getContextSortByDimsFirst()
+          query.getContextSortByDimsFirst(),
+          keySize
       );
     }
 
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GrouperBufferComparatorUtils.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GrouperBufferComparatorUtils.java
index f8d8488..467a729 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GrouperBufferComparatorUtils.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GrouperBufferComparatorUtils.java
@@ -126,7 +126,8 @@
       List<DimensionSpec> dimensions,
       Grouper.BufferComparator[] dimComparators,
       boolean includeTimestamp,
-      boolean sortByDimsFirst
+      boolean sortByDimsFirst,
+      int keyBufferTotalSize
   )
   {
     int dimCount = dimensions.size();
@@ -148,7 +149,8 @@
         if (aggIndex >= 0) {
           final StringComparator stringComparator = orderSpec.getDimensionComparator();
           final ValueType valueType = aggregatorFactories[aggIndex].getType();
-          final int aggOffset = aggregatorOffsets[aggIndex] - Integer.BYTES;
+          // Aggregators start after dimensions
+          final int aggOffset = keyBufferTotalSize + aggregatorOffsets[aggIndex];
 
           aggCount++;
 
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
index f812f14..8a3e2a4 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
@@ -1252,7 +1252,10 @@
           dimensions,
           serdeHelperComparators,
           includeTimestamp,
-          sortByDimsFirst
+          sortByDimsFirst,
+          Arrays.stream(serdeHelpers)
+                .mapToInt(RowBasedKeySerdeHelper::getKeyBufferValueSize)
+                .sum()
       );
     }
 
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByLimitPushDownMultiNodeMergeTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByLimitPushDownMultiNodeMergeTest.java
index 3c62e11..7cf7e65 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByLimitPushDownMultiNodeMergeTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByLimitPushDownMultiNodeMergeTest.java
@@ -30,6 +30,7 @@
 import org.apache.druid.collections.CloseableStupidPool;
 import org.apache.druid.data.input.InputRow;
 import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.data.input.impl.DimensionSchema;
 import org.apache.druid.data.input.impl.DimensionsSpec;
 import org.apache.druid.data.input.impl.LongDimensionSchema;
 import org.apache.druid.data.input.impl.StringDimensionSchema;
@@ -141,6 +142,14 @@
 
   private IncrementalIndex makeIncIndex(boolean withRollup)
   {
+    return makeIncIndex(withRollup, Arrays.asList(
+        new StringDimensionSchema("dimA"),
+        new LongDimensionSchema("metA")
+    ));
+  }
+
+  private IncrementalIndex makeIncIndex(boolean withRollup, List<DimensionSchema> dimensions)
+  {
     return new OnheapIncrementalIndex.Builder()
         .setIndexSchema(
             new IncrementalIndexSchema.Builder()
@@ -311,7 +320,217 @@
     );
     QueryableIndex qindexD = INDEX_IO.loadIndex(fileD);
 
-    groupByIndices = Arrays.asList(qindexA, qindexB, qindexC, qindexD);
+    List<String> dimNames2 = Arrays.asList("dimA", "dimB", "metA");
+    List<DimensionSchema> dimensions = Arrays.asList(
+        new StringDimensionSchema("dimA"),
+        new StringDimensionSchema("dimB"),
+        new LongDimensionSchema("metA")
+    );
+    final IncrementalIndex indexE = makeIncIndex(false, dimensions);
+    incrementalIndices.add(indexE);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("dimB", "raw");
+    event.put("metA", 5L);
+    row = new MapBasedInputRow(1505260800000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("dimB", "ripe");
+    event.put("metA", 9L);
+    row = new MapBasedInputRow(1605260800000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("dimB", "raw");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1705264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("dimB", "ripe");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "grape");
+    event.put("dimB", "raw");
+    event.put("metA", 5L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "apple");
+    event.put("dimB", "ripe");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "apple");
+    event.put("dimB", "raw");
+    event.put("metA", 1L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "apple");
+    event.put("dimB", "ripe");
+    event.put("metA", 4L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "apple");
+    event.put("dimB", "raw");
+    event.put("metA", 1L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "banana");
+    event.put("dimB", "ripe");
+    event.put("metA", 4L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "orange");
+    event.put("dimB", "raw");
+    event.put("metA", 9L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "peach");
+    event.put("dimB", "ripe");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "orange");
+    event.put("dimB", "raw");
+    event.put("metA", 2L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "strawberry");
+    event.put("dimB", "ripe");
+    event.put("metA", 10L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexE.add(row);
+
+    final File fileE = INDEX_MERGER_V9.persist(
+        indexE,
+        new File(tmpDir, "E"),
+        new IndexSpec(),
+        null
+    );
+    QueryableIndex qindexE = INDEX_IO.loadIndex(fileE);
+
+    final IncrementalIndex indexF = makeIncIndex(false, dimensions);
+    incrementalIndices.add(indexF);
+
+    event = new HashMap<>();
+    event.put("dimA", "kiwi");
+    event.put("dimB", "raw");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1505260800000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "watermelon");
+    event.put("dimB", "ripe");
+    event.put("metA", 14L);
+    row = new MapBasedInputRow(1605260800000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "kiwi");
+    event.put("dimB", "raw");
+    event.put("metA", 8L);
+    row = new MapBasedInputRow(1705264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "kiwi");
+    event.put("dimB", "ripe");
+    event.put("metA", 8L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "lemon");
+    event.put("dimB", "raw");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "cherry");
+    event.put("dimB", "ripe");
+    event.put("metA", 2L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "cherry");
+    event.put("dimB", "raw");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "avocado");
+    event.put("dimB", "ripe");
+    event.put("metA", 12L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "cherry");
+    event.put("dimB", "raw");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "plum");
+    event.put("dimB", "ripe");
+    event.put("metA", 5L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "plum");
+    event.put("dimB", "raw");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "lime");
+    event.put("dimB", "ripe");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1805264400000L, dimNames2, event);
+    indexF.add(row);
+
+    final File fileF = INDEX_MERGER_V9.persist(
+        indexF,
+        new File(tmpDir, "F"),
+        new IndexSpec(),
+        null
+    );
+    QueryableIndex qindexF = INDEX_IO.loadIndex(fileF);
+
+    groupByIndices = Arrays.asList(qindexA, qindexB, qindexC, qindexD, qindexE, qindexF);
     resourceCloser = Closer.create();
     setupGroupByFactory();
   }
@@ -704,6 +923,95 @@
     Assert.assertEquals(expectedRow3, results.get(3));
   }
 
+  @Test
+  public void testForcePushLimitDownAccuracyWhenSortHasNonGroupingFields()
+  {
+    // The two testing segments have non overlapping groups, so the result should be 100% accurate even
+    // forceLimitPushDown is applied
+    List<ResultRow> resultsWithoutLimitPushDown = testForcePushLimitDownAccuracyWhenSortHasNonGroupingFieldsHelper(ImmutableMap.of());
+    List<ResultRow> resultsWithLimitPushDown = testForcePushLimitDownAccuracyWhenSortHasNonGroupingFieldsHelper(ImmutableMap.of(
+        GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true,
+        GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, true
+    ));
+
+    List<ResultRow> expectedResults = ImmutableList.of(
+        ResultRow.of("mango", "ripe", 16),
+        ResultRow.of("kiwi", "raw", 15),
+        ResultRow.of("watermelon", "ripe", 14),
+        ResultRow.of("avocado", "ripe", 12),
+        ResultRow.of("orange", "raw", 11)
+    );
+
+    Assert.assertEquals(expectedResults.toString(), resultsWithoutLimitPushDown.toString());
+    Assert.assertEquals(expectedResults.toString(), resultsWithLimitPushDown.toString());
+  }
+
+  private List<ResultRow> testForcePushLimitDownAccuracyWhenSortHasNonGroupingFieldsHelper(Map<String, Object> context)
+  {
+    QueryToolChest<ResultRow, GroupByQuery> toolChest = groupByFactory.getToolchest();
+    QueryRunner<ResultRow> theRunner = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            groupByFactory.mergeRunners(executorService, getRunner1(4))
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QueryRunner<ResultRow> theRunner2 = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            groupByFactory2.mergeRunners(executorService, getRunner2(5))
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QueryRunner<ResultRow> finalRunner = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            new QueryRunner<ResultRow>()
+            {
+              @Override
+              public Sequence<ResultRow> run(QueryPlus<ResultRow> queryPlus, ResponseContext responseContext)
+              {
+                return Sequences
+                    .simple(
+                        ImmutableList.of(
+                            theRunner.run(queryPlus, responseContext),
+                            theRunner2.run(queryPlus, responseContext)
+                        )
+                    )
+                    .flatMerge(Function.identity(), queryPlus.getQuery().getResultOrdering());
+              }
+            }
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QuerySegmentSpec intervalSpec = new MultipleIntervalSegmentSpec(
+        Collections.singletonList(Intervals.utc(1500000000000L, 1900000000000L))
+    );
+
+    DefaultLimitSpec ls = new DefaultLimitSpec(
+        Collections.singletonList(
+            new OrderByColumnSpec("a0", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC)
+        ),
+        5
+    );
+
+    GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource("blah")
+        .setQuerySegmentSpec(intervalSpec)
+        .setDimensions(
+            new DefaultDimensionSpec("dimA", "d0", ValueType.STRING),
+            new DefaultDimensionSpec("dimB", "d1", ValueType.STRING)
+        ).setAggregatorSpecs(new LongSumAggregatorFactory("a0", "metA"))
+        .setLimitSpec(ls)
+        .setContext(context)
+        .setGranularity(Granularities.ALL)
+        .build();
+
+    Sequence<ResultRow> queryResult = finalRunner.run(QueryPlus.wrap(query), ResponseContext.createEmpty());
+    return queryResult.toList();
+  }
+
   private List<QueryRunner<ResultRow>> getRunner1(int qIndexNumber)
   {
     List<QueryRunner<ResultRow>> runners = new ArrayList<>();