| From d1dbca13ace135626af9ecd4cfaf1c49912117cd Mon Sep 17 00:00:00 2001 |
| From: Julien Massenet <julien.massenet@mail.rakuten.com> |
| Date: Mon, 6 Mar 2017 14:45:16 +0100 |
| Subject: [PATCH] Ignore NaN maxScore values when merging TopGroups during |
| distributed grouped search |
| |
| --- |
| .../apache/lucene/search/grouping/TopGroups.java | 15 +- |
| .../lucene/search/grouping/TopGroupsTest.java | 264 +++++++++++++++++++++ |
| 2 files changed, 275 insertions(+), 4 deletions(-) |
| create mode 100644 lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java |
| |
| diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java |
| index 36ab8d9b07..209deec6bd 100644 |
| --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java |
| +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TopGroups.java |
| @@ -133,12 +133,12 @@ public class TopGroups<T> { |
| } else { |
| shardTopDocs = new TopFieldDocs[shardGroups.length]; |
| } |
| - float totalMaxScore = Float.MIN_VALUE; |
| + float totalMaxScore = Float.NaN; |
| |
| for(int groupIDX=0;groupIDX<numGroups;groupIDX++) { |
| final T groupValue = shardGroups[0].groups[groupIDX].groupValue; |
| //System.out.println(" merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues)); |
| - float maxScore = Float.MIN_VALUE; |
| + float maxScore = Float.NaN; |
| int totalHits = 0; |
| double scoreSum = 0.0; |
| for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) { |
| @@ -169,7 +169,7 @@ public class TopGroups<T> { |
| docSort.getSort(), |
| shardGroupDocs.maxScore); |
| } |
| - maxScore = Math.max(maxScore, shardGroupDocs.maxScore); |
| + maxScore = max(maxScore, shardGroupDocs.maxScore); |
| totalHits += shardGroupDocs.totalHits; |
| scoreSum += shardGroupDocs.score; |
| } |
| @@ -222,7 +222,7 @@ public class TopGroups<T> { |
| mergedScoreDocs, |
| groupValue, |
| shardGroups[0].groups[groupIDX].groupSortValues); |
| - totalMaxScore = Math.max(totalMaxScore, maxScore); |
| + totalMaxScore = max(totalMaxScore, maxScore); |
| } |
| |
| if (totalGroupCount != null) { |
| @@ -242,4 +242,30 @@ public class TopGroups<T> { |
| totalMaxScore); |
| } |
| } |
| + |
| +  /** |
| +   * Returns the larger of two scores, treating {@code NaN} as "no score" rather than as a value. |
| +   * |
| +   * <p>{@link Math#max(float, float)} yields {@code NaN} whenever either argument is |
| +   * {@code NaN}; this method only does so when both are. |
| +   * |
| +   * @param a first score, may be {@code NaN} |
| +   * @param b second score, may be {@code NaN} |
| +   * @return {@code NaN} if both arguments are {@code NaN}, the other argument if exactly one |
| +   *         is {@code NaN}, and {@code Math.max(a, b)} otherwise |
| +   */ |
| +  public static float max(float a, float b) { |
| +    final boolean aMissing = Float.isNaN(a); |
| +    final boolean bMissing = Float.isNaN(b); |
| +    if (aMissing && bMissing) { |
| +      return Float.NaN; |
| +    } |
| +    if (aMissing) { |
| +      return b; |
| +    } |
| +    if (bMissing) { |
| +      return a; |
| +    } |
| +    return Math.max(a, b); |
| +  } |
| } |
| diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java |
| new file mode 100644 |
| index 0000000000..194fcff7a2 |
| --- /dev/null |
| +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TopGroupsTest.java |
| @@ -0,0 +1,300 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements.  See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License.  You may obtain a copy of the License at |
| + * |
| + *     http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| +package org.apache.lucene.search.grouping; |
| + |
| +import junit.framework.TestCase; |
| +import org.apache.lucene.search.ScoreDoc; |
| +import org.apache.lucene.search.Sort; |
| +import org.apache.lucene.search.SortField; |
| +import org.apache.lucene.util.LuceneTestCase; |
| + |
| +import java.io.IOException; |
| +import java.util.ArrayList; |
| +import java.util.List; |
| + |
| +/** |
| + * Tests the NaN-aware {@link TopGroups#max(float, float)} helper and the way |
| + * {@code TopGroups.merge} combines per-shard {@code maxScore} values when some |
| + * shards return groups without any documents. |
| + */ |
| +public final class TopGroupsTest extends LuceneTestCase { |
| +  /** NaN operands are skipped; NaN is returned only when both operands are NaN. */ |
| +  public void testMax() { |
| +    assertTrue(Float.isNaN(TopGroups.max(Float.NaN, Float.NaN))); |
| +    // max() returns one of its operands unchanged, so exact equality is expected. |
| +    assertEquals(2.0f, TopGroups.max(2.0f, Float.NaN), 0.0f); |
| +    assertEquals(2.0f, TopGroups.max(Float.NaN, 2.0f), 0.0f); |
| +    assertEquals(2.0f, TopGroups.max(1.0f, 2.0f), 0.0f); |
| +    assertEquals(2.0f, TopGroups.max(2.0f, 1.0f), 0.0f); |
| +    // Negative scores must be handled the same way. |
| +    assertEquals(-2.0f, TopGroups.max(-2.0f, Float.NaN), 0.0f); |
| +    assertEquals(-1.0f, TopGroups.max(-2.0f, -1.0f), 0.0f); |
| +  } |
| + |
| +  /** |
| +   * Merges four shard responses, two of which carry only empty groups (NaN maxScore), |
| +   * and checks that the merged maxScore values come from the shards that have documents. |
| +   */ |
| +  @SuppressWarnings("unchecked") |
| +  public void testMerge() throws IOException { |
| +    // Given the following groups received from the shards |
| +    // (shard1 and shard4 contribute no documents, so their maxScore is NaN) |
| +    TopGroups<Integer> shard1 = newTopGroup() |
| +        .withGroup(newGroup().withValue(1).build()) |
| +        .withGroup(newGroup().withValue(2).build()) |
| +        .build(); |
| + |
| +    TopGroups<Integer> shard2 = newTopGroup() |
| +        .withGroup(newGroup().withValue(1) |
| +            .withDoc(newScoreDoc().withId(5).withScore(2.0f).withShardId(1).build()) |
| +            .withDoc(newScoreDoc().withId(1).withScore(1.0f).withShardId(1).build()) |
| +            .build()) |
| +        .withGroup(newGroup().withValue(2) |
| +            .withDoc(newScoreDoc().withId(10).withScore(3.0f).withShardId(1).build()) |
| +            .build()) |
| +        .build(); |
| + |
| +    TopGroups<Integer> shard3 = newTopGroup() |
| +        .withGroup(newGroup().withValue(1) |
| +            .withDoc(newScoreDoc().withId(1).withScore(4.0f).withShardId(2).build()) |
| +            .build()) |
| +        .withGroup(newGroup().withValue(2).build()) |
| +        .build(); |
| + |
| +    TopGroups<Integer> shard4 = newTopGroup() |
| +        .withGroup(newGroup().withValue(1).build()) |
| +        .withGroup(newGroup().withValue(2).build()) |
| +        .build(); |
| + |
| +    // When merging the groups |
| +    TopGroups<Integer> merged = TopGroups.merge( |
| +        new TopGroups[]{shard1, shard2, shard3, shard4}, // shardGroups |
| +        Sort.RELEVANCE, // groupSort |
| +        Sort.RELEVANCE, // docSort |
| +        0, // docOffset |
| +        2, // docTopN |
| +        TopGroups.ScoreMergeMode.None // scoreMergeMode |
| +    ); |
| + |
| +    // Expect the following results |
| +    assertNotNull(merged); |
| +    assertEquals(4, merged.totalHitCount); |
| +    assertEquals(4, merged.totalGroupedHitCount); |
| +    assertEquals(4.0f, merged.maxScore, 0.1f); |
| +    assertEquals(2, merged.groups.length); |
| + |
| +    { |
| +      GroupDocs<Integer> group1 = merged.groups[0]; |
| +      assertGroupDocs(group1) |
| +          .withValue(1) |
| +          .withMaxScore(4.0f) |
| +          .withNbDocs(2) |
| +          .withTotalHits(3); |
| + |
| +      ScoreDoc group1Doc1 = group1.scoreDocs[0]; |
| +      assertScoreDoc(group1Doc1) |
| +          .withId(1) |
| +          .withShardId(2) |
| +          .withScore(4.0f); |
| + |
| +      ScoreDoc group1Doc2 = group1.scoreDocs[1]; |
| +      assertScoreDoc(group1Doc2) |
| +          .withId(5) |
| +          .withShardId(1) |
| +          .withScore(2.0f); |
| +    } |
| + |
| +    { |
| +      GroupDocs<Integer> group2 = merged.groups[1]; |
| +      assertGroupDocs(group2) |
| +          .withValue(2) |
| +          .withMaxScore(3.0f) |
| +          .withNbDocs(1) |
| +          .withTotalHits(1); |
| + |
| +      ScoreDoc group2Doc1 = group2.scoreDocs[0]; |
| +      assertScoreDoc(group2Doc1) |
| +          .withId(10) |
| +          .withShardId(1) |
| +          .withScore(3.0f); |
| +    } |
| +  } |
| + |
| +  // -------------------------------------------- |
| +  // Helpers |
| +  // -------------------------------------------- |
| + |
| +  private static TopGroupBuilder<Integer> newTopGroup() { |
| +    return new TopGroupBuilder<Integer>(); |
| +  } |
| + |
| +  private static GroupDocsBuilder<Integer> newGroup() { |
| +    return new GroupDocsBuilder<Integer>(); |
| +  } |
| + |
| +  private static ScoreDocBuilder newScoreDoc() { |
| +    return new ScoreDocBuilder(); |
| +  } |
| + |
| +  private static <T> GroupDocsAssert<T> assertGroupDocs(GroupDocs<T> actual) { |
| +    return new GroupDocsAssert<T>(actual); |
| +  } |
| + |
| +  private static ScoreDocAssert assertScoreDoc(ScoreDoc actual) { |
| +    return new ScoreDocAssert(actual); |
| +  } |
| + |
| +  /** Builds one shard's {@link TopGroups}; hit count and maxScore are derived from its groups. */ |
| +  private static final class TopGroupBuilder<T> { |
| +    private final List<GroupDocs<T>> groups = new ArrayList<GroupDocs<T>>(); |
| + |
| +    TopGroupBuilder<T> withGroup(GroupDocs<T> group) { |
| +      groups.add(group); |
| +      return this; |
| +    } |
| + |
| +    @SuppressWarnings("unchecked") |
| +    TopGroups<T> build() { |
| +      int totalHitCount = 0; |
| +      float maxScore = Float.NaN; // stays NaN when no group has a score |
| +      for (GroupDocs<T> group : groups) { |
| +        totalHitCount += group.totalHits; |
| +        maxScore = TopGroups.max(maxScore, group.maxScore); |
| +      } |
| +      return new TopGroups<T>( |
| +          new SortField[]{SortField.FIELD_SCORE}, // groupSort |
| +          new SortField[]{SortField.FIELD_SCORE}, // withinGroupSort |
| +          totalHitCount, // totalHitCount |
| +          totalHitCount, // totalGroupedHitCount |
| +          groups.toArray(new GroupDocs[groups.size()]), // groups |
| +          maxScore // maxScore |
| +      ); |
| +    } |
| +  } |
| + |
| +  /** Builds a {@link GroupDocs}; totalHits and maxScore are derived from the added docs. */ |
| +  private static final class GroupDocsBuilder<T> { |
| +    private T value; |
| +    private final List<ScoreDoc> docs = new ArrayList<ScoreDoc>(); |
| + |
| +    GroupDocsBuilder<T> withValue(T value) { |
| +      this.value = value; |
| +      return this; |
| +    } |
| + |
| +    GroupDocsBuilder<T> withDoc(ScoreDoc doc) { |
| +      this.docs.add(doc); |
| +      return this; |
| +    } |
| + |
| +    GroupDocs<T> build() { |
| +      float maxScore = Float.NaN; // an empty group ends up with a NaN maxScore |
| +      int totalHits = 0; |
| +      for (ScoreDoc doc : docs) { |
| +        maxScore = TopGroups.max(maxScore, doc.score); |
| +        totalHits++; |
| +      } |
| +      return new GroupDocs<T>( |
| +          Float.NaN, // score |
| +          maxScore, // maxScore |
| +          totalHits, // totalHits |
| +          docs.toArray(new ScoreDoc[docs.size()]), // scoreDocs |
| +          value, // groupValue |
| +          null // groupSortValue |
| +      ); |
| +    } |
| +  } |
| + |
| +  /** Builds a {@link ScoreDoc} from a doc id, a score and a shard index. */ |
| +  private static final class ScoreDocBuilder { |
| +    private int id; |
| +    private int shardId; |
| +    private float score; |
| + |
| +    ScoreDocBuilder withId(int id) { |
| +      this.id = id; |
| +      return this; |
| +    } |
| + |
| +    ScoreDocBuilder withShardId(int shardId) { |
| +      this.shardId = shardId; |
| +      return this; |
| +    } |
| + |
| +    ScoreDocBuilder withScore(float score) { |
| +      this.score = score; |
| +      return this; |
| +    } |
| + |
| +    ScoreDoc build() { |
| +      return new ScoreDoc(id, score, shardId); |
| +    } |
| +  } |
| + |
| +  /** Chainable assertions on a {@link GroupDocs}. */ |
| +  private static final class GroupDocsAssert<T> { |
| +    private final GroupDocs<T> actual; |
| + |
| +    GroupDocsAssert(GroupDocs<T> actual) { |
| +      this.actual = actual; |
| +    } |
| + |
| +    GroupDocsAssert<T> withValue(T expected) { |
| +      assertEquals("Invalid group value", expected, actual.groupValue); |
| +      return this; |
| +    } |
| + |
| +    GroupDocsAssert<T> withTotalHits(int expected) { |
| +      assertEquals("Invalid number of total hits", expected, actual.totalHits); |
| +      return this; |
| +    } |
| + |
| +    GroupDocsAssert<T> withMaxScore(float expected) { |
| +      assertEquals("Invalid maximum score", expected, actual.maxScore, 0.1f); |
| +      return this; |
| +    } |
| + |
| +    GroupDocsAssert<T> withNbDocs(int expected) { |
| +      assertEquals("Invalid number of score docs", expected, actual.scoreDocs.length); |
| +      return this; |
| +    } |
| +  } |
| + |
| +  /** Chainable assertions on a {@link ScoreDoc}. */ |
| +  private static final class ScoreDocAssert { |
| +    private final ScoreDoc actual; |
| + |
| +    ScoreDocAssert(ScoreDoc actual) { |
| +      this.actual = actual; |
| +    } |
| + |
| +    ScoreDocAssert withId(int expected) { |
| +      assertEquals("Invalid doc ID", expected, actual.doc); |
| +      return this; |
| +    } |
| + |
| +    ScoreDocAssert withScore(float expected) { |
| +      assertEquals("Invalid score", expected, actual.score, 0.1f); |
| +      return this; |
| +    } |
| + |
| +    ScoreDocAssert withShardId(int expected) { |
| +      assertEquals("Invalid shard ID", expected, actual.shardIndex); |
| +      return this; |
| +    } |
| +  } |
| +} |
| -- |
| 2.11.0 |
| |