LUCENE-9302: Grouping to use long to avoid overflows
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
index 23601ca..dad1101 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
@@ -353,8 +353,8 @@
return new TopGroups<>(new TopGroups<>(groupSort.getSort(),
withinGroupSort.getSort(),
- totalHitCount, totalGroupedHitCount, groups, maxScore),
- totalGroupCount);
+ (long) totalHitCount, (long) totalGroupedHitCount, groups, maxScore),
+ (long) totalGroupCount);
}
@Override
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
index b88fb74..6ac5dc1 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupingSearch.java
@@ -161,7 +161,7 @@
}
if (allGroups) {
- return new TopGroups(secondPassCollector.getTopGroups(groupDocsOffset), matchingGroups.size());
+ return new TopGroups(secondPassCollector.getTopGroups(groupDocsOffset), (long) matchingGroups.size());
} else {
return secondPassCollector.getTopGroups(groupDocsOffset);
}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
index cb84400..d941a1a 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
@@ -29,13 +29,13 @@
* @lucene.experimental */
public class TopGroups<T> {
/** Number of documents matching the search */
- public final int totalHitCount;
+ public final long totalHitCount;
/** Number of documents grouped into the topN groups */
- public final int totalGroupedHitCount;
+ public final long totalGroupedHitCount;
/** The total number of unique groups. If <code>null</code> this value is not computed. */
- public final Integer totalGroupCount;
+ public final Long totalGroupCount;
/** Group results in groupSort order */
public final GroupDocs<T>[] groups;
@@ -50,7 +50,7 @@
* <code>Float.NaN</code> if scores were not computed. */
public final float maxScore;
- public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, int totalHitCount, int totalGroupedHitCount, GroupDocs<T>[] groups, float maxScore) {
+ public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, long totalHitCount, long totalGroupedHitCount, GroupDocs<T>[] groups, float maxScore) {
this.groupSort = groupSort;
this.withinGroupSort = withinGroupSort;
this.totalHitCount = totalHitCount;
@@ -60,7 +60,7 @@
this.maxScore = maxScore;
}
- public TopGroups(TopGroups<T> oldTopGroups, Integer totalGroupCount) {
+ public TopGroups(TopGroups<T> oldTopGroups, Long totalGroupCount) {
this.groupSort = oldTopGroups.groupSort;
this.withinGroupSort = oldTopGroups.withinGroupSort;
this.totalHitCount = oldTopGroups.totalHitCount;
@@ -118,10 +118,10 @@
return null;
}
- int totalHitCount = 0;
- int totalGroupedHitCount = 0;
+ long totalHitCount = 0;
+ long totalGroupedHitCount = 0;
// Optionally merge the totalGroupCount.
- Integer totalGroupCount = null;
+ Long totalGroupCount = null;
final int numGroups = shardGroups[0].groups.length;
for(TopGroups<T> shard : shardGroups) {
@@ -132,7 +132,7 @@
totalGroupedHitCount += shard.totalGroupedHitCount;
if (shard.totalGroupCount != null) {
if (totalGroupCount == null) {
- totalGroupCount = 0;
+ totalGroupCount = 0L;
}
totalGroupCount += shard.totalGroupCount;
@@ -154,7 +154,7 @@
final T groupValue = shardGroups[0].groups[groupIDX].groupValue;
//System.out.println(" merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues));
float maxScore = Float.NaN;
- int totalHits = 0;
+ long totalHits = 0;
double scoreSum = 0.0;
for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) {
//System.out.println(" shard=" + shardIDX);
diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
index f1ce508..73832c1 100644
--- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
+++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java
@@ -452,7 +452,7 @@
final List<BytesRef> sortedGroups = new ArrayList<>();
final List<Comparable<?>[]> sortedGroupFields = new ArrayList<>();
- int totalHitCount = 0;
+ long totalHitCount = 0;
Set<BytesRef> knownGroups = new HashSet<>();
//System.out.println("TEST: slowGrouping");
@@ -492,7 +492,7 @@
final Comparator<GroupDoc> docSortComp = getComparator(docSort);
@SuppressWarnings({"unchecked","rawtypes"})
final GroupDocs<BytesRef>[] result = new GroupDocs[limit-groupOffset];
- int totalGroupedHitCount = 0;
+ long totalGroupedHitCount = 0;
for(int idx=groupOffset;idx < limit;idx++) {
final BytesRef group = sortedGroups.get(idx);
final List<GroupDoc> docs = groups.get(group);
@@ -523,7 +523,7 @@
if (doAllGroups) {
return new TopGroups<>(
new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN),
- knownGroups.size()
+ (long) knownGroups.size()
);
} else {
return new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN);
@@ -960,7 +960,7 @@
if (doAllGroups) {
TopGroups<BytesRef> tempTopGroups = getTopGroups(c2, docOffset);
- groupsResult = new TopGroups<>(tempTopGroups, allGroupsCollector.getGroupCount());
+ groupsResult = new TopGroups<>(tempTopGroups, (long) allGroupsCollector.getGroupCount());
} else {
groupsResult = getTopGroups(c2, docOffset);
}
@@ -1046,8 +1046,8 @@
final TopGroups<BytesRef> tempTopGroupsBlocks = (TopGroups<BytesRef>) c3.getTopGroups(docSort, groupOffset, docOffset, docOffset+docsPerGroup);
final TopGroups<BytesRef> groupsResultBlocks;
if (doAllGroups && tempTopGroupsBlocks != null) {
- assertEquals((int) tempTopGroupsBlocks.totalGroupCount, allGroupsCollector2.getGroupCount());
- groupsResultBlocks = new TopGroups<>(tempTopGroupsBlocks, allGroupsCollector2.getGroupCount());
+ assertEquals((long) tempTopGroupsBlocks.totalGroupCount, (long) allGroupsCollector2.getGroupCount());
+ groupsResultBlocks = new TopGroups<>(tempTopGroupsBlocks, (long) allGroupsCollector2.getGroupCount());
} else {
groupsResultBlocks = tempTopGroupsBlocks;
}