blob: ea84f4cdb5144d0e88fd15ba168756cd80755aa2 [file] [log] [blame]
From d1dbca13ace135626af9ecd4cfaf1c49912117cd Mon Sep 17 00:00:00 2001
From: Julien Massenet <julien.massenet@mail.rakuten.com>
Date: Mon, 6 Mar 2017 14:45:16 +0100
Subject: [PATCH] maxScore merging happens as expected during distributed
grouped search
---
.../apache/lucene/search/grouping/TopGroups.java | 15 +-
.../lucene/search/grouping/TopGroupsTest.java | 264 +++++++++++++++++++++
2 files changed, 275 insertions(+), 4 deletions(-)
create mode 100644 lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
index 36ab8d9b07..209deec6bd 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java
@@ -133,12 +133,12 @@ public class TopGroups<T> {
} else {
shardTopDocs = new TopFieldDocs[shardGroups.length];
}
- float totalMaxScore = Float.MIN_VALUE;
+ float totalMaxScore = Float.NaN;
for(int groupIDX=0;groupIDX<numGroups;groupIDX++) {
final T groupValue = shardGroups[0].groups[groupIDX].groupValue;
//System.out.println(" merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues));
- float maxScore = Float.MIN_VALUE;
+ float maxScore = Float.NaN;
int totalHits = 0;
double scoreSum = 0.0;
for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) {
@@ -169,7 +169,7 @@ public class TopGroups<T> {
docSort.getSort(),
shardGroupDocs.maxScore);
}
- maxScore = Math.max(maxScore, shardGroupDocs.maxScore);
+ maxScore = max(maxScore, shardGroupDocs.maxScore);
totalHits += shardGroupDocs.totalHits;
scoreSum += shardGroupDocs.score;
}
@@ -222,7 +222,7 @@ public class TopGroups<T> {
mergedScoreDocs,
groupValue,
shardGroups[0].groups[groupIDX].groupSortValues);
- totalMaxScore = Math.max(totalMaxScore, maxScore);
+ totalMaxScore = max(totalMaxScore, maxScore);
}
if (totalGroupCount != null) {
@@ -242,4 +242,11 @@ public class TopGroups<T> {
totalMaxScore);
}
}
+
+ public static float max(float a, float b) {
+ if (Float.isNaN(a)) {
+ return Float.isNaN(b) ? Float.NaN : b;
+ }
+ return Float.isNaN(b) ? a : Math.max(a, b);
+ }
}
diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java
new file mode 100644
index 0000000000..194fcff7a2
--- /dev/null
+++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java
@@ -0,0 +1,264 @@
+package org.apache.lucene.search.grouping;
+
+import junit.framework.TestCase;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.util.LuceneTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public final class TopGroupsTest extends LuceneTestCase {
+ public void testMax() {
+ assertTrue(Float.isNaN(TopGroups.max(Float.NaN, Float.NaN)));
+ assertEquals(2.0f, TopGroups.max(2.0f, Float.NaN), 0.1f);
+ assertEquals(2.0f, TopGroups.max(Float.NaN, 2.0f), 0.1f);
+ assertEquals(2.0f, TopGroups.max(1.0f, 2.0f), 0.1f);
+ assertEquals(2.0f, TopGroups.max(2.0f, 1.0f), 0.1f);
+ }
+
+ @SuppressWarnings("unchecked")
+ public void testMerge() throws IOException {
+ //Given the following groups received from the shards
+ TopGroups<Integer> shard1 = newTopGroup()
+ .withGroup(newGroup().withValue(1).build())
+ .withGroup(newGroup().withValue(2).build())
+ .build();
+
+ TopGroups<Integer> shard2 = newTopGroup()
+ .withGroup(newGroup().withValue(1)
+ .withDoc(newScoreDoc().withId(5).withScore(2.0f).withShardId(1).build())
+ .withDoc(newScoreDoc().withId(1).withScore(1.0f).withShardId(1).build())
+ .build())
+ .withGroup(newGroup().withValue(2)
+ .withDoc(newScoreDoc().withId(10).withScore(3.0f).withShardId(1).build())
+ .build())
+ .build();
+
+ TopGroups<Integer> shard3 = newTopGroup()
+ .withGroup(newGroup().withValue(1)
+ .withDoc(newScoreDoc().withId(1).withScore(4.0f).withShardId(2).build())
+ .build())
+ .withGroup(newGroup().withValue(2).build())
+ .build();
+
+ TopGroups<Integer> shard4 = newTopGroup()
+ .withGroup(newGroup().withValue(1).build())
+ .withGroup(newGroup().withValue(2).build())
+ .build();
+
+ // When merging the groups
+ TopGroups<Integer> merged = TopGroups.merge(
+ new TopGroups[]{shard1, shard2, shard3, shard4}, // shardGroups
+ Sort.RELEVANCE, // groupSort
+ Sort.RELEVANCE, // docSort
+ 0, // docOffset
+ 2, // docTopN
+ TopGroups.ScoreMergeMode.None // scoreMergeMode
+ );
+
+ // Expect the following results
+ assertNotNull(merged);
+ assertEquals(4, merged.totalHitCount);
+ assertEquals(4, merged.totalGroupedHitCount);
+ assertEquals(4.0f, merged.maxScore, 0.1f);
+ assertEquals(2, merged.groups.length);
+
+ {
+ GroupDocs<Integer> group1 = merged.groups[0];
+ assertGroupDocs(group1)
+ .withValue(1)
+ .withMaxScore(4.0f)
+ .withNbDocs(2)
+ .withTotalHits(3);
+
+ ScoreDoc group1Doc1 = group1.scoreDocs[0];
+ assertScoreDoc(group1Doc1)
+ .withId(1)
+ .withShardId(2)
+ .withScore(4.0f);
+
+ ScoreDoc group1Doc2 = group1.scoreDocs[1];
+ assertScoreDoc(group1Doc2)
+ .withId(5)
+ .withShardId(1)
+ .withScore(2.0f);
+ }
+
+ {
+ GroupDocs<Integer> group2 = merged.groups[1];
+ assertGroupDocs(group2)
+ .withValue(2)
+ .withMaxScore(3.0f)
+ .withNbDocs(1).
+ withTotalHits(1);
+
+ ScoreDoc group2Doc1 = group2.scoreDocs[0];
+ assertScoreDoc(group2Doc1)
+ .withId(10)
+ .withShardId(1)
+ .withScore(3.0f);
+ }
+ }
+
+ // --------------------------------------------
+ // Helpers
+ // --------------------------------------------
+
+ private static TopGroupBuilder<Integer> newTopGroup() {
+ return new TopGroupBuilder<Integer>();
+ }
+
+ private static GroupDocsBuilder<Integer> newGroup() {
+ return new GroupDocsBuilder<Integer>();
+ }
+
+ private static ScoreDocBuilder newScoreDoc() {
+ return new ScoreDocBuilder();
+ }
+
+ private static <T> GroupDocsAssert<T> assertGroupDocs(GroupDocs<T> actual) {
+ return new GroupDocsAssert<T>(actual);
+ }
+
+ private static ScoreDocAssert assertScoreDoc(ScoreDoc actual) {
+ return new ScoreDocAssert(actual);
+ }
+
+ private static final class TopGroupBuilder<T> {
+ private final List<GroupDocs<T>> groups = new ArrayList<GroupDocs<T>>();
+
+ TopGroupBuilder<T> withGroup(GroupDocs<T> group) {
+ groups.add(group);
+ return this;
+ }
+
+ @SuppressWarnings("unchecked")
+ TopGroups<T> build() {
+ int totalHitCount = 0;
+ float maxScore = Float.NaN;
+ for (GroupDocs<T> group : groups) {
+ totalHitCount += group.totalHits;
+ maxScore = TopGroups.max(maxScore, group.maxScore);
+ }
+ return new TopGroups<T>(
+ new SortField[]{SortField.FIELD_SCORE}, // groupSort
+ new SortField[]{SortField.FIELD_SCORE}, // withinGroupSort
+ totalHitCount, // totalHitCount
+ totalHitCount, // totalGroupedHitCount
+ groups.toArray(new GroupDocs[groups.size()]), // groups
+ maxScore // maxScore
+ );
+ }
+ }
+
+ private static final class GroupDocsBuilder<T> {
+ private T value;
+ private final List<ScoreDoc> docs = new ArrayList<ScoreDoc>();
+
+ GroupDocsBuilder<T> withValue(T value) {
+ this.value = value;
+ return this;
+ }
+
+ GroupDocsBuilder<T> withDoc(ScoreDoc doc) {
+ this.docs.add(doc);
+ return this;
+ }
+
+ GroupDocs<T> build() {
+ float maxScore = Float.NaN;
+ int totalHits = 0;
+ for (ScoreDoc doc : docs) {
+ maxScore = TopGroups.max(maxScore, doc.score);
+ totalHits++;
+ }
+ return new GroupDocs<T>(
+ Float.NaN, // score
+ maxScore, // maxScore
+ totalHits, // totalHits
+ docs.toArray(new ScoreDoc[docs.size()]), // scoreDocs
+ value, // groupValue
+ null // groupSortValue
+ );
+ }
+ }
+
+ private static final class ScoreDocBuilder {
+ private int id;
+ private int shardId;
+ private float score;
+
+ ScoreDocBuilder withId(int id) {
+ this.id = id;
+ return this;
+ }
+
+ ScoreDocBuilder withShardId(int shardId) {
+ this.shardId = shardId;
+ return this;
+ }
+
+ ScoreDocBuilder withScore(float score) {
+ this.score = score;
+ return this;
+ }
+
+ ScoreDoc build() {
+ return new ScoreDoc(id, score, shardId);
+ }
+ }
+
+ private static final class GroupDocsAssert<T> {
+ private final GroupDocs<T> actual;
+
+ GroupDocsAssert(GroupDocs<T> actual) {
+ this.actual = actual;
+ }
+
+ GroupDocsAssert<T> withValue(T expected) {
+ assertEquals("Invalid group value", expected, actual.groupValue);
+ return this;
+ }
+
+ GroupDocsAssert<T> withTotalHits(int expected) {
+ assertEquals("Invalid number of total hits", expected, actual.totalHits);
+ return this;
+ }
+
+ GroupDocsAssert<T> withMaxScore(float expected) {
+ assertEquals("Invalid maximum score", expected, actual.maxScore, 0.1f);
+ return this;
+ }
+
+ GroupDocsAssert<T> withNbDocs(int expected) {
+ assertEquals("Invalid number of score docs", expected, actual.scoreDocs.length);
+ return this;
+ }
+ }
+
+ private static final class ScoreDocAssert {
+ private final ScoreDoc actual;
+
+ ScoreDocAssert(ScoreDoc actual) {
+ this.actual = actual;
+ }
+
+ ScoreDocAssert withId(int expected) {
+ assertEquals("Invalid doc ID", expected, actual.doc);
+ return this;
+ }
+
+ ScoreDocAssert withScore(float expected) {
+ assertEquals("Invalid score", expected, actual.score, 0.1f);
+ return this;
+ }
+
+ ScoreDocAssert withShardId(int expected) {
+ assertEquals("Invalid shard ID", expected, actual.shardIndex);
+ return this;
+ }
+ }
+}
\ No newline at end of file
--
2.11.0