blob: f1ce508e134ab39d944553204446bad206e8009a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.BytesRefFieldSource;
import org.apache.lucene.search.CachingCollector;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.apache.lucene.util.mutable.MutableValue;
import org.apache.lucene.util.mutable.MutableValueStr;
// TODO
// - should test relevance sort too
// - test null
// - test ties
// - test compound sort
public class TestGrouping extends LuceneTestCase {
public void testBasic() throws Exception {
String groupField = "author";
FieldType customType = new FieldType();
customType.setStored(true);
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy()));
// 0
Document doc = new Document();
addGroupField(doc, groupField, "author1");
doc.add(new TextField("content", "random text", Field.Store.YES));
doc.add(new Field("id", "1", customType));
w.addDocument(doc);
// 1
doc = new Document();
addGroupField(doc, groupField, "author1");
doc.add(new TextField("content", "some more random text", Field.Store.YES));
doc.add(new Field("id", "2", customType));
w.addDocument(doc);
// 2
doc = new Document();
addGroupField(doc, groupField, "author1");
doc.add(new TextField("content", "some more random textual data", Field.Store.YES));
doc.add(new Field("id", "3", customType));
w.addDocument(doc);
// 3
doc = new Document();
addGroupField(doc, groupField, "author2");
doc.add(new TextField("content", "some random text", Field.Store.YES));
doc.add(new Field("id", "4", customType));
w.addDocument(doc);
// 4
doc = new Document();
addGroupField(doc, groupField, "author3");
doc.add(new TextField("content", "some more random text", Field.Store.YES));
doc.add(new Field("id", "5", customType));
w.addDocument(doc);
// 5
doc = new Document();
addGroupField(doc, groupField, "author3");
doc.add(new TextField("content", "random", Field.Store.YES));
doc.add(new Field("id", "6", customType));
w.addDocument(doc);
// 6 -- no author field
doc = new Document();
doc.add(new TextField("content", "random word stuck in alot of other text", Field.Store.YES));
doc.add(new Field("id", "6", customType));
w.addDocument(doc);
IndexSearcher indexSearcher = newSearcher(w.getReader());
// This test relies on the fact that longer fields produce lower scores
indexSearcher.setSimilarity(new BM25Similarity());
w.close();
final Sort groupSort = Sort.RELEVANCE;
final FirstPassGroupingCollector<?> c1 = createRandomFirstPassCollector(groupField, groupSort, 10);
indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
final TopGroupsCollector<?> c2 = createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true);
indexSearcher.search(new TermQuery(new Term("content", "random")), c2);
final TopGroups<?> groups = c2.getTopGroups(0);
assertFalse(Float.isNaN(groups.maxScore));
assertEquals(7, groups.totalHitCount);
assertEquals(7, groups.totalGroupedHitCount);
assertEquals(4, groups.groups.length);
// relevance order: 5, 0, 3, 4, 1, 2, 6
// the later a document is added the higher this docId
// value
GroupDocs<?> group = groups.groups[0];
compareGroupValue("author3", group);
assertEquals(2, group.scoreDocs.length);
assertEquals(5, group.scoreDocs[0].doc);
assertEquals(4, group.scoreDocs[1].doc);
assertTrue(group.scoreDocs[0].score > group.scoreDocs[1].score);
group = groups.groups[1];
compareGroupValue("author1", group);
assertEquals(3, group.scoreDocs.length);
assertEquals(0, group.scoreDocs[0].doc);
assertEquals(1, group.scoreDocs[1].doc);
assertEquals(2, group.scoreDocs[2].doc);
assertTrue(group.scoreDocs[0].score >= group.scoreDocs[1].score);
assertTrue(group.scoreDocs[1].score >= group.scoreDocs[2].score);
group = groups.groups[2];
compareGroupValue("author2", group);
assertEquals(1, group.scoreDocs.length);
assertEquals(3, group.scoreDocs[0].doc);
group = groups.groups[3];
compareGroupValue(null, group);
assertEquals(1, group.scoreDocs.length);
assertEquals(6, group.scoreDocs[0].doc);
indexSearcher.getIndexReader().close();
dir.close();
}
private void addGroupField(Document doc, String groupField, String value) {
doc.add(new SortedDocValuesField(groupField, new BytesRef(value)));
}
private FirstPassGroupingCollector<?> createRandomFirstPassCollector(String groupField, Sort groupSort, int topDocs) throws IOException {
if (random().nextBoolean()) {
ValueSource vs = new BytesRefFieldSource(groupField);
return new FirstPassGroupingCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), groupSort, topDocs);
} else {
return new FirstPassGroupingCollector<>(new TermGroupSelector(groupField), groupSort, topDocs);
}
}
private FirstPassGroupingCollector<?> createFirstPassCollector(String groupField, Sort groupSort, int topDocs, FirstPassGroupingCollector<?> firstPassGroupingCollector) throws IOException {
GroupSelector<?> selector = firstPassGroupingCollector.getGroupSelector();
if (TermGroupSelector.class.isAssignableFrom(selector.getClass())) {
ValueSource vs = new BytesRefFieldSource(groupField);
return new FirstPassGroupingCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), groupSort, topDocs);
} else {
return new FirstPassGroupingCollector<>(new TermGroupSelector(groupField), groupSort, topDocs);
}
}
@SuppressWarnings({"unchecked","rawtypes"})
private <T> TopGroupsCollector<T> createSecondPassCollector(FirstPassGroupingCollector firstPassGroupingCollector,
Sort groupSort,
Sort sortWithinGroup,
int groupOffset,
int maxDocsPerGroup,
boolean getMaxScores) throws IOException {
Collection<SearchGroup<T>> searchGroups = firstPassGroupingCollector.getTopGroups(groupOffset);
return new TopGroupsCollector<>(firstPassGroupingCollector.getGroupSelector(), searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
}
// Basically converts searchGroups from MutableValue to BytesRef if grouping by ValueSource
@SuppressWarnings("unchecked")
private TopGroupsCollector<?> createSecondPassCollector(FirstPassGroupingCollector<?> firstPassGroupingCollector,
String groupField,
Collection<SearchGroup<BytesRef>> searchGroups,
Sort groupSort,
Sort sortWithinGroup,
int maxDocsPerGroup,
boolean getMaxScores) throws IOException {
if (firstPassGroupingCollector.getGroupSelector().getClass().isAssignableFrom(TermGroupSelector.class)) {
GroupSelector<BytesRef> selector = (GroupSelector<BytesRef>) firstPassGroupingCollector.getGroupSelector();
return new TopGroupsCollector<>(selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
} else {
ValueSource vs = new BytesRefFieldSource(groupField);
List<SearchGroup<MutableValue>> mvalSearchGroups = new ArrayList<>(searchGroups.size());
for (SearchGroup<BytesRef> mergedTopGroup : searchGroups) {
SearchGroup<MutableValue> sg = new SearchGroup<>();
MutableValueStr groupValue = new MutableValueStr();
if (mergedTopGroup.groupValue != null) {
groupValue.value.copyBytes(mergedTopGroup.groupValue);
} else {
groupValue.exists = false;
}
sg.groupValue = groupValue;
sg.sortValues = mergedTopGroup.sortValues;
mvalSearchGroups.add(sg);
}
ValueSourceGroupSelector selector = new ValueSourceGroupSelector(vs, new HashMap<>());
return new TopGroupsCollector<>(selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
}
}
private AllGroupsCollector<?> createAllGroupsCollector(FirstPassGroupingCollector<?> firstPassGroupingCollector,
String groupField) {
return new AllGroupsCollector<>(firstPassGroupingCollector.getGroupSelector());
}
private void compareGroupValue(String expected, GroupDocs<?> group) {
if (expected == null) {
if (group.groupValue == null) {
return;
} else if (group.groupValue.getClass().isAssignableFrom(MutableValueStr.class)) {
return;
} else if (((BytesRef) group.groupValue).length == 0) {
return;
}
fail();
}
if (group.groupValue.getClass().isAssignableFrom(BytesRef.class)) {
assertEquals(new BytesRef(expected), group.groupValue);
} else if (group.groupValue.getClass().isAssignableFrom(MutableValueStr.class)) {
MutableValueStr v = new MutableValueStr();
v.value.copyChars(expected);
assertEquals(v, group.groupValue);
} else {
fail();
}
}
private Collection<SearchGroup<BytesRef>> getSearchGroups(FirstPassGroupingCollector<?> c, int groupOffset) throws IOException {
if (TermGroupSelector.class.isAssignableFrom(c.getGroupSelector().getClass())) {
FirstPassGroupingCollector<BytesRef> collector = (FirstPassGroupingCollector<BytesRef>) c;
return collector.getTopGroups(groupOffset);
} else if (ValueSourceGroupSelector.class.isAssignableFrom(c.getGroupSelector().getClass())) {
FirstPassGroupingCollector<MutableValue> collector = (FirstPassGroupingCollector<MutableValue>) c;
Collection<SearchGroup<MutableValue>> mutableValueGroups = collector.getTopGroups(groupOffset);
if (mutableValueGroups == null) {
return null;
}
List<SearchGroup<BytesRef>> groups = new ArrayList<>(mutableValueGroups.size());
for (SearchGroup<MutableValue> mutableValueGroup : mutableValueGroups) {
SearchGroup<BytesRef> sg = new SearchGroup<>();
sg.groupValue = mutableValueGroup.groupValue.exists() ? ((MutableValueStr) mutableValueGroup.groupValue).value.get() : null;
sg.sortValues = mutableValueGroup.sortValues;
groups.add(sg);
}
return groups;
}
fail();
return null;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private TopGroups<BytesRef> getTopGroups(TopGroupsCollector c, int withinGroupOffset) {
if (c.getGroupSelector().getClass().isAssignableFrom(TermGroupSelector.class)) {
TopGroupsCollector<BytesRef> collector = (TopGroupsCollector<BytesRef>) c;
return collector.getTopGroups(withinGroupOffset);
} else if (c.getGroupSelector().getClass().isAssignableFrom(ValueSourceGroupSelector.class)) {
TopGroupsCollector<MutableValue> collector = (TopGroupsCollector<MutableValue>) c;
TopGroups<MutableValue> mvalTopGroups = collector.getTopGroups(withinGroupOffset);
List<GroupDocs<BytesRef>> groups = new ArrayList<>(mvalTopGroups.groups.length);
for (GroupDocs<MutableValue> mvalGd : mvalTopGroups.groups) {
BytesRef groupValue = mvalGd.groupValue.exists() ? ((MutableValueStr) mvalGd.groupValue).value.get() : null;
groups.add(new GroupDocs<>(Float.NaN, mvalGd.maxScore, mvalGd.totalHits, mvalGd.scoreDocs, groupValue, mvalGd.groupSortValues));
}
// NOTE: currenlty using diamond operator on MergedIterator (without explicit Term class) causes
// errors on Eclipse Compiler (ecj) used for javadoc lint
return new TopGroups<BytesRef>(mvalTopGroups.groupSort, mvalTopGroups.withinGroupSort, mvalTopGroups.totalHitCount, mvalTopGroups.totalGroupedHitCount, groups.toArray(new GroupDocs[groups.size()]), Float.NaN);
}
fail();
return null;
}
private static class GroupDoc {
final int id;
final BytesRef group;
final BytesRef sort1;
final BytesRef sort2;
// content must be "realN ..."
final String content;
float score;
float score2;
public GroupDoc(int id, BytesRef group, BytesRef sort1, BytesRef sort2, String content) {
this.id = id;
this.group = group;
this.sort1 = sort1;
this.sort2 = sort2;
this.content = content;
}
}
private Sort getRandomSort() {
final List<SortField> sortFields = new ArrayList<>();
if (random().nextInt(7) == 2) {
sortFields.add(SortField.FIELD_SCORE);
} else {
if (random().nextBoolean()) {
if (random().nextBoolean()) {
sortFields.add(new SortField("sort1", SortField.Type.STRING, random().nextBoolean()));
} else {
sortFields.add(new SortField("sort2", SortField.Type.STRING, random().nextBoolean()));
}
} else if (random().nextBoolean()) {
sortFields.add(new SortField("sort1", SortField.Type.STRING, random().nextBoolean()));
sortFields.add(new SortField("sort2", SortField.Type.STRING, random().nextBoolean()));
}
}
// Break ties:
sortFields.add(new SortField("id", SortField.Type.INT));
return new Sort(sortFields.toArray(new SortField[sortFields.size()]));
}
private Comparator<GroupDoc> getComparator(Sort sort) {
final SortField[] sortFields = sort.getSort();
return new Comparator<GroupDoc>() {
@Override
public int compare(GroupDoc d1, GroupDoc d2) {
for(SortField sf : sortFields) {
final int cmp;
if (sf.getType() == SortField.Type.SCORE) {
if (d1.score > d2.score) {
cmp = -1;
} else if (d1.score < d2.score) {
cmp = 1;
} else {
cmp = 0;
}
} else if (sf.getField().equals("sort1")) {
cmp = d1.sort1.compareTo(d2.sort1);
} else if (sf.getField().equals("sort2")) {
cmp = d1.sort2.compareTo(d2.sort2);
} else {
assertEquals(sf.getField(), "id");
cmp = d1.id - d2.id;
}
if (cmp != 0) {
return sf.getReverse() ? -cmp : cmp;
}
}
// Our sort always fully tie breaks:
fail();
return 0;
}
};
}
@SuppressWarnings({"unchecked","rawtypes"})
private Comparable<?>[] fillFields(GroupDoc d, Sort sort) {
final SortField[] sortFields = sort.getSort();
final Comparable<?>[] fields = new Comparable[sortFields.length];
for(int fieldIDX=0;fieldIDX<sortFields.length;fieldIDX++) {
final Comparable<?> c;
final SortField sf = sortFields[fieldIDX];
if (sf.getType() == SortField.Type.SCORE) {
c = d.score;
} else if (sf.getField().equals("sort1")) {
c = d.sort1;
} else if (sf.getField().equals("sort2")) {
c = d.sort2;
} else {
assertEquals("id", sf.getField());
c = d.id;
}
fields[fieldIDX] = c;
}
return fields;
}
private String groupToString(BytesRef b) {
if (b == null) {
return "null";
} else {
return b.utf8ToString();
}
}
private TopGroups<BytesRef> slowGrouping(GroupDoc[] groupDocs,
String searchTerm,
boolean getMaxScores,
boolean doAllGroups,
Sort groupSort,
Sort docSort,
int topNGroups,
int docsPerGroup,
int groupOffset,
int docOffset) {
final Comparator<GroupDoc> groupSortComp = getComparator(groupSort);
Arrays.sort(groupDocs, groupSortComp);
final HashMap<BytesRef,List<GroupDoc>> groups = new HashMap<>();
final List<BytesRef> sortedGroups = new ArrayList<>();
final List<Comparable<?>[]> sortedGroupFields = new ArrayList<>();
int totalHitCount = 0;
Set<BytesRef> knownGroups = new HashSet<>();
//System.out.println("TEST: slowGrouping");
for(GroupDoc d : groupDocs) {
// TODO: would be better to filter by searchTerm before sorting!
if (!d.content.startsWith(searchTerm)) {
continue;
}
totalHitCount++;
//System.out.println(" match id=" + d.id + " score=" + d.score);
if (doAllGroups) {
if (!knownGroups.contains(d.group)) {
knownGroups.add(d.group);
//System.out.println(" add group=" + groupToString(d.group));
}
}
List<GroupDoc> l = groups.get(d.group);
if (l == null) {
//System.out.println(" add sortedGroup=" + groupToString(d.group));
sortedGroups.add(d.group);
sortedGroupFields.add(fillFields(d, groupSort));
l = new ArrayList<>();
groups.put(d.group, l);
}
l.add(d);
}
if (groupOffset >= sortedGroups.size()) {
// slice is out of bounds
return null;
}
final int limit = Math.min(groupOffset + topNGroups, groups.size());
final Comparator<GroupDoc> docSortComp = getComparator(docSort);
@SuppressWarnings({"unchecked","rawtypes"})
final GroupDocs<BytesRef>[] result = new GroupDocs[limit-groupOffset];
int totalGroupedHitCount = 0;
for(int idx=groupOffset;idx < limit;idx++) {
final BytesRef group = sortedGroups.get(idx);
final List<GroupDoc> docs = groups.get(group);
totalGroupedHitCount += docs.size();
Collections.sort(docs, docSortComp);
final ScoreDoc[] hits;
if (docs.size() > docOffset) {
final int docIDXLimit = Math.min(docOffset + docsPerGroup, docs.size());
hits = new ScoreDoc[docIDXLimit - docOffset];
for(int docIDX=docOffset; docIDX < docIDXLimit; docIDX++) {
final GroupDoc d = docs.get(docIDX);
final FieldDoc fd;
fd = new FieldDoc(d.id, Float.NaN, fillFields(d, docSort));
hits[docIDX-docOffset] = fd;
}
} else {
hits = new ScoreDoc[0];
}
result[idx-groupOffset] = new GroupDocs<>(Float.NaN,
0.0f,
new TotalHits(docs.size(), TotalHits.Relation.EQUAL_TO),
hits,
group,
sortedGroupFields.get(idx));
}
if (doAllGroups) {
return new TopGroups<>(
new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN),
knownGroups.size()
);
} else {
return new TopGroups<>(groupSort.getSort(), docSort.getSort(), totalHitCount, totalGroupedHitCount, result, Float.NaN);
}
}
private DirectoryReader getDocBlockReader(Directory dir, GroupDoc[] groupDocs) throws IOException {
// Coalesce by group, but in random order:
Collections.shuffle(Arrays.asList(groupDocs), random());
final Map<BytesRef,List<GroupDoc>> groupMap = new HashMap<>();
final List<BytesRef> groupValues = new ArrayList<>();
for(GroupDoc groupDoc : groupDocs) {
if (!groupMap.containsKey(groupDoc.group)) {
groupValues.add(groupDoc.group);
groupMap.put(groupDoc.group, new ArrayList<GroupDoc>());
}
groupMap.get(groupDoc.group).add(groupDoc);
}
RandomIndexWriter w = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(new MockAnalyzer(random())));
final List<List<Document>> updateDocs = new ArrayList<>();
FieldType groupEndType = new FieldType(StringField.TYPE_NOT_STORED);
groupEndType.setIndexOptions(IndexOptions.DOCS);
groupEndType.setOmitNorms(true);
//System.out.println("TEST: index groups");
for(BytesRef group : groupValues) {
final List<Document> docs = new ArrayList<>();
//System.out.println("TEST: group=" + (group == null ? "null" : group.utf8ToString()));
for(GroupDoc groupValue : groupMap.get(group)) {
Document doc = new Document();
docs.add(doc);
if (groupValue.group != null) {
doc.add(newStringField("group", groupValue.group.utf8ToString(), Field.Store.YES));
doc.add(new SortedDocValuesField("group", BytesRef.deepCopyOf(groupValue.group)));
}
doc.add(newStringField("sort1", groupValue.sort1.utf8ToString(), Field.Store.NO));
doc.add(new SortedDocValuesField("sort1", BytesRef.deepCopyOf(groupValue.sort1)));
doc.add(newStringField("sort2", groupValue.sort2.utf8ToString(), Field.Store.NO));
doc.add(new SortedDocValuesField("sort2", BytesRef.deepCopyOf(groupValue.sort2)));
doc.add(new NumericDocValuesField("id", groupValue.id));
doc.add(newTextField("content", groupValue.content, Field.Store.NO));
//System.out.println("TEST: doc content=" + groupValue.content + " group=" + (groupValue.group == null ? "null" : groupValue.group.utf8ToString()) + " sort1=" + groupValue.sort1.utf8ToString() + " id=" + groupValue.id);
}
// So we can pull filter marking last doc in block:
final Field groupEnd = newField("groupend", "x", groupEndType);
docs.get(docs.size()-1).add(groupEnd);
// Add as a doc block:
w.addDocuments(docs);
if (group != null && random().nextInt(7) == 4) {
updateDocs.add(docs);
}
}
for(List<Document> docs : updateDocs) {
// Just replaces docs w/ same docs:
w.updateDocuments(new Term("group", docs.get(0).get("group")), docs);
}
final DirectoryReader r = w.getReader();
w.close();
return r;
}
private static class ShardState {
public final ShardSearcher[] subSearchers;
public final int[] docStarts;
public ShardState(IndexSearcher s) {
final IndexReaderContext ctx = s.getTopReaderContext();
final List<LeafReaderContext> leaves = ctx.leaves();
subSearchers = new ShardSearcher[leaves.size()];
for(int searcherIDX=0;searcherIDX<subSearchers.length;searcherIDX++) {
subSearchers[searcherIDX] = new ShardSearcher(leaves.get(searcherIDX), ctx);
}
docStarts = new int[subSearchers.length];
for(int subIDX=0;subIDX<docStarts.length;subIDX++) {
docStarts[subIDX] = leaves.get(subIDX).docBase;
//System.out.println("docStarts[" + subIDX + "]=" + docStarts[subIDX]);
}
}
}
public void testRandom() throws Exception {
int numberOfRuns = atLeast(1);
for (int iter=0; iter<numberOfRuns; iter++) {
if (VERBOSE) {
System.out.println("TEST: iter=" + iter);
}
final int numDocs = atLeast(100);
//final int numDocs = _TestUtil.nextInt(random, 5, 20);
final int numGroups = TestUtil.nextInt(random(), 1, numDocs);
if (VERBOSE) {
System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups);
}
final List<BytesRef> groups = new ArrayList<>();
for(int i=0;i<numGroups;i++) {
String randomValue;
do {
// B/c of DV based impl we can't see the difference between an empty string and a null value.
// For that reason we don't generate empty string
// groups.
randomValue = TestUtil.randomRealisticUnicodeString(random());
//randomValue = TestUtil.randomSimpleString(random());
} while ("".equals(randomValue));
groups.add(new BytesRef(randomValue));
}
final String[] contentStrings = new String[TestUtil.nextInt(random(), 2, 20)];
if (VERBOSE) {
System.out.println("TEST: create fake content");
}
for(int contentIDX=0;contentIDX<contentStrings.length;contentIDX++) {
final StringBuilder sb = new StringBuilder();
sb.append("real").append(random().nextInt(3)).append(' ');
final int fakeCount = random().nextInt(10);
for(int fakeIDX=0;fakeIDX<fakeCount;fakeIDX++) {
sb.append("fake ");
}
contentStrings[contentIDX] = sb.toString();
if (VERBOSE) {
System.out.println(" content=" + sb.toString());
}
}
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(new MockAnalyzer(random())));
Document doc = new Document();
Document docNoGroup = new Document();
Field idvGroupField = new SortedDocValuesField("group", new BytesRef());
doc.add(idvGroupField);
docNoGroup.add(idvGroupField);
Field group = newStringField("group", "", Field.Store.NO);
doc.add(group);
Field sort1 = new SortedDocValuesField("sort1", new BytesRef());
doc.add(sort1);
docNoGroup.add(sort1);
Field sort2 = new SortedDocValuesField("sort2", new BytesRef());
doc.add(sort2);
docNoGroup.add(sort2);
Field content = newTextField("content", "", Field.Store.NO);
doc.add(content);
docNoGroup.add(content);
NumericDocValuesField idDV = new NumericDocValuesField("id", 0);
doc.add(idDV);
docNoGroup.add(idDV);
final GroupDoc[] groupDocs = new GroupDoc[numDocs];
for(int i=0;i<numDocs;i++) {
final BytesRef groupValue;
if (random().nextInt(24) == 17) {
// So we test the "doc doesn't have the group'd
// field" case:
groupValue = null;
} else {
groupValue = groups.get(random().nextInt(groups.size()));
}
final GroupDoc groupDoc = new GroupDoc(i,
groupValue,
groups.get(random().nextInt(groups.size())),
groups.get(random().nextInt(groups.size())),
contentStrings[random().nextInt(contentStrings.length)]);
if (VERBOSE) {
System.out.println(" doc content=" + groupDoc.content + " id=" + i + " group=" + (groupDoc.group == null ? "null" : groupDoc.group.utf8ToString()) + " sort1=" + groupDoc.sort1.utf8ToString() + " sort2=" + groupDoc.sort2.utf8ToString());
}
groupDocs[i] = groupDoc;
if (groupDoc.group != null) {
group.setStringValue(groupDoc.group.utf8ToString());
idvGroupField.setBytesValue(BytesRef.deepCopyOf(groupDoc.group));
} else {
// TODO: not true
// Must explicitly set empty string, else eg if
// the segment has all docs missing the field then
// we get null back instead of empty BytesRef:
idvGroupField.setBytesValue(new BytesRef());
}
sort1.setBytesValue(BytesRef.deepCopyOf(groupDoc.sort1));
sort2.setBytesValue(BytesRef.deepCopyOf(groupDoc.sort2));
content.setStringValue(groupDoc.content);
idDV.setLongValue(groupDoc.id);
if (groupDoc.group == null) {
w.addDocument(docNoGroup);
} else {
w.addDocument(doc);
}
}
final GroupDoc[] groupDocsByID = new GroupDoc[groupDocs.length];
System.arraycopy(groupDocs, 0, groupDocsByID, 0, groupDocs.length);
final DirectoryReader r = w.getReader();
w.close();
NumericDocValues values = MultiDocValues.getNumericValues(r, "id");
int[] docIDToID = new int[r.maxDoc()];
for(int i=0;i<r.maxDoc();i++) {
assertEquals(i, values.nextDoc());
docIDToID[i] = (int) values.longValue();
}
DirectoryReader rBlocks = null;
Directory dirBlocks = null;
final IndexSearcher s = newSearcher(r);
// This test relies on the fact that longer fields produce lower scores
s.setSimilarity(new BM25Similarity());
if (VERBOSE) {
System.out.println("\nTEST: searcher=" + s);
}
final ShardState shards = new ShardState(s);
Set<Integer> seenIDs = new HashSet<>();
for(int contentID=0;contentID<3;contentID++) {
final ScoreDoc[] hits = s.search(new TermQuery(new Term("content", "real"+contentID)), numDocs).scoreDocs;
for(ScoreDoc hit : hits) {
int idValue = docIDToID[hit.doc];
final GroupDoc gd = groupDocs[idValue];
seenIDs.add(idValue);
assertTrue(gd.score == 0.0);
gd.score = hit.score;
assertEquals(gd.id, idValue);
}
}
// make sure all groups were seen across the hits
assertEquals(groupDocs.length, seenIDs.size());
for(GroupDoc gd : groupDocs) {
assertTrue(Float.isFinite(gd.score));
assertTrue(gd.score >= 0.0);
}
// Build 2nd index, where docs are added in blocks by
// group, so we can use single pass collector
dirBlocks = newDirectory();
rBlocks = getDocBlockReader(dirBlocks, groupDocs);
final Query lastDocInBlock = new TermQuery(new Term("groupend", "x"));
final IndexSearcher sBlocks = newSearcher(rBlocks);
// This test relies on the fact that longer fields produce lower scores
sBlocks.setSimilarity(new BM25Similarity());
final ShardState shardsBlocks = new ShardState(sBlocks);
// ReaderBlocks only increases maxDoc() vs reader, which
// means a monotonic shift in scores, so we can
// reliably remap them w/ Map:
final Map<String,Map<Float,Float>> scoreMap = new HashMap<>();
values = MultiDocValues.getNumericValues(rBlocks, "id");
assertNotNull(values);
int[] docIDToIDBlocks = new int[rBlocks.maxDoc()];
for(int i=0;i<rBlocks.maxDoc();i++) {
assertEquals(i, values.nextDoc());
docIDToIDBlocks[i] = (int) values.longValue();
}
// Tricky: must separately set .score2, because the doc
// block index was created with possible deletions!
//System.out.println("fixup score2");
for(int contentID=0;contentID<3;contentID++) {
//System.out.println(" term=real" + contentID);
final Map<Float,Float> termScoreMap = new HashMap<>();
scoreMap.put("real"+contentID, termScoreMap);
//System.out.println("term=real" + contentID + " dfold=" + s.docFreq(new Term("content", "real"+contentID)) +
//" dfnew=" + sBlocks.docFreq(new Term("content", "real"+contentID)));
final ScoreDoc[] hits = sBlocks.search(new TermQuery(new Term("content", "real"+contentID)), numDocs).scoreDocs;
for(ScoreDoc hit : hits) {
final GroupDoc gd = groupDocsByID[docIDToIDBlocks[hit.doc]];
assertTrue(gd.score2 == 0.0);
gd.score2 = hit.score;
assertEquals(gd.id, docIDToIDBlocks[hit.doc]);
//System.out.println(" score=" + gd.score + " score2=" + hit.score + " id=" + docIDToIDBlocks[hit.doc]);
termScoreMap.put(gd.score, gd.score2);
}
}
for(int searchIter=0;searchIter<100;searchIter++) {
if (VERBOSE) {
System.out.println("\nTEST: searchIter=" + searchIter);
}
final String searchTerm = "real" + random().nextInt(3);
final boolean getMaxScores = random().nextBoolean();
final Sort groupSort = getRandomSort();
//final Sort groupSort = new Sort(new SortField[] {new SortField("sort1", SortField.STRING), new SortField("id", SortField.INT)});
final Sort docSort = getRandomSort();
final int topNGroups = TestUtil.nextInt(random(), 1, 30);
//final int topNGroups = 10;
final int docsPerGroup = TestUtil.nextInt(random(), 1, 50);
final int groupOffset = TestUtil.nextInt(random(), 0, (topNGroups - 1) / 2);
//final int groupOffset = 0;
final int docOffset = TestUtil.nextInt(random(), 0, docsPerGroup - 1);
//final int docOffset = 0;
final boolean doCache = random().nextBoolean();
final boolean doAllGroups = random().nextBoolean();
if (VERBOSE) {
System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " dF=" + r.docFreq(new Term("content", searchTerm)) +" dFBlock=" + rBlocks.docFreq(new Term("content", searchTerm)) + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getMaxScores=" + getMaxScores);
}
String groupField = "group";
if (VERBOSE) {
System.out.println(" groupField=" + groupField);
}
final FirstPassGroupingCollector<?> c1 = createRandomFirstPassCollector(groupField, groupSort, groupOffset+topNGroups);
final CachingCollector cCache;
final Collector c;
final AllGroupsCollector<?> allGroupsCollector;
if (doAllGroups) {
allGroupsCollector = createAllGroupsCollector(c1, groupField);
} else {
allGroupsCollector = null;
}
final boolean useWrappingCollector = random().nextBoolean();
if (doCache) {
final double maxCacheMB = random().nextDouble();
if (VERBOSE) {
System.out.println("TEST: maxCacheMB=" + maxCacheMB);
}
if (useWrappingCollector) {
if (doAllGroups) {
cCache = CachingCollector.create(c1, true, maxCacheMB);
c = MultiCollector.wrap(cCache, allGroupsCollector);
} else {
c = cCache = CachingCollector.create(c1, true, maxCacheMB);
}
} else {
// Collect only into cache, then replay multiple times:
c = cCache = CachingCollector.create(true, maxCacheMB);
}
} else {
cCache = null;
if (doAllGroups) {
c = MultiCollector.wrap(c1, allGroupsCollector);
} else {
c = c1;
}
}
// Search top reader:
final Query query = new TermQuery(new Term("content", searchTerm));
s.search(query, c);
if (doCache && !useWrappingCollector) {
if (cCache.isCached()) {
// Replay for first-pass grouping
cCache.replay(c1);
if (doAllGroups) {
// Replay for all groups:
cCache.replay(allGroupsCollector);
}
} else {
// Replay by re-running search:
s.search(query, c1);
if (doAllGroups) {
s.search(query, allGroupsCollector);
}
}
}
// Get 1st pass top groups
final Collection<SearchGroup<BytesRef>> topGroups = getSearchGroups(c1, groupOffset);
final TopGroups<BytesRef> groupsResult;
if (VERBOSE) {
System.out.println("TEST: first pass topGroups");
if (topGroups == null) {
System.out.println(" null");
} else {
for (SearchGroup<BytesRef> searchGroup : topGroups) {
System.out.println(" " + (searchGroup.groupValue == null ? "null" : searchGroup.groupValue) + ": " + Arrays.deepToString(searchGroup.sortValues));
}
}
}
// Get 1st pass top groups using shards
final TopGroups<BytesRef> topGroupsShards = searchShards(s, shards.subSearchers, query, groupSort, docSort,
groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, true, true);
final TopGroupsCollector<?> c2;
if (topGroups != null) {
if (VERBOSE) {
System.out.println("TEST: topGroups");
for (SearchGroup<BytesRef> searchGroup : topGroups) {
System.out.println(" " + (searchGroup.groupValue == null ? "null" : searchGroup.groupValue.utf8ToString()) + ": " + Arrays.deepToString(searchGroup.sortValues));
}
}
c2 = createSecondPassCollector(c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getMaxScores);
if (doCache) {
if (cCache.isCached()) {
if (VERBOSE) {
System.out.println("TEST: cache is intact");
}
cCache.replay(c2);
} else {
if (VERBOSE) {
System.out.println("TEST: cache was too large");
}
s.search(query, c2);
}
} else {
s.search(query, c2);
}
if (doAllGroups) {
TopGroups<BytesRef> tempTopGroups = getTopGroups(c2, docOffset);
groupsResult = new TopGroups<>(tempTopGroups, allGroupsCollector.getGroupCount());
} else {
groupsResult = getTopGroups(c2, docOffset);
}
} else {
c2 = null;
groupsResult = null;
if (VERBOSE) {
System.out.println("TEST: no results");
}
}
final TopGroups<BytesRef> expectedGroups = slowGrouping(groupDocs, searchTerm, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset);
if (VERBOSE) {
if (expectedGroups == null) {
System.out.println("TEST: no expected groups");
} else {
System.out.println("TEST: expected groups totalGroupedHitCount=" + expectedGroups.totalGroupedHitCount);
for(GroupDocs<BytesRef> gd : expectedGroups.groups) {
System.out.println(" group=" + (gd.groupValue == null ? "null" : gd.groupValue) + " totalHits=" + gd.totalHits.value + " scoreDocs.len=" + gd.scoreDocs.length);
for(ScoreDoc sd : gd.scoreDocs) {
System.out.println(" id=" + sd.doc + " score=" + sd.score);
}
}
}
if (groupsResult == null) {
System.out.println("TEST: no matched groups");
} else {
System.out.println("TEST: matched groups totalGroupedHitCount=" + groupsResult.totalGroupedHitCount);
for(GroupDocs<BytesRef> gd : groupsResult.groups) {
System.out.println(" group=" + (gd.groupValue == null ? "null" : gd.groupValue) + " totalHits=" + gd.totalHits.value);
for(ScoreDoc sd : gd.scoreDocs) {
System.out.println(" id=" + docIDToID[sd.doc] + " score=" + sd.score);
}
}
if (searchIter == 14) {
for(int docIDX=0;docIDX<s.getIndexReader().maxDoc();docIDX++) {
System.out.println("ID=" + docIDToID[docIDX] + " explain=" + s.explain(query, docIDX));
}
}
}
if (topGroupsShards == null) {
System.out.println("TEST: no matched-merged groups");
} else {
System.out.println("TEST: matched-merged groups totalGroupedHitCount=" + topGroupsShards.totalGroupedHitCount);
for(GroupDocs<BytesRef> gd : topGroupsShards.groups) {
System.out.println(" group=" + (gd.groupValue == null ? "null" : gd.groupValue) + " totalHits=" + gd.totalHits.value);
for(ScoreDoc sd : gd.scoreDocs) {
System.out.println(" id=" + docIDToID[sd.doc] + " score=" + sd.score);
}
}
}
}
assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true);
// Confirm merged shards match:
assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, true);
if (topGroupsShards != null) {
verifyShards(shards.docStarts, topGroupsShards);
}
final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups,
groupSort.needsScores() || docSort.needsScores(), sBlocks.createWeight(sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
final AllGroupsCollector<BytesRef> allGroupsCollector2;
final Collector c4;
if (doAllGroups) {
// NOTE: must be "group" and not "group_dv"
// (groupField) because we didn't index doc
// values in the block index:
allGroupsCollector2 = new AllGroupsCollector<>(new TermGroupSelector("group"));
c4 = MultiCollector.wrap(c3, allGroupsCollector2);
} else {
allGroupsCollector2 = null;
c4 = c3;
}
// Get block grouping result:
sBlocks.search(query, c4);
@SuppressWarnings({"unchecked","rawtypes"})
final TopGroups<BytesRef> tempTopGroupsBlocks = (TopGroups<BytesRef>) c3.getTopGroups(docSort, groupOffset, docOffset, docOffset+docsPerGroup);
final TopGroups<BytesRef> groupsResultBlocks;
if (doAllGroups && tempTopGroupsBlocks != null) {
assertEquals((int) tempTopGroupsBlocks.totalGroupCount, allGroupsCollector2.getGroupCount());
groupsResultBlocks = new TopGroups<>(tempTopGroupsBlocks, allGroupsCollector2.getGroupCount());
} else {
groupsResultBlocks = tempTopGroupsBlocks;
}
if (VERBOSE) {
if (groupsResultBlocks == null) {
System.out.println("TEST: no block groups");
} else {
System.out.println("TEST: block groups totalGroupedHitCount=" + groupsResultBlocks.totalGroupedHitCount);
boolean first = true;
for(GroupDocs<BytesRef> gd : groupsResultBlocks.groups) {
System.out.println(" group=" + (gd.groupValue == null ? "null" : gd.groupValue.utf8ToString()) + " totalHits=" + gd.totalHits.value);
for(ScoreDoc sd : gd.scoreDocs) {
System.out.println(" id=" + docIDToIDBlocks[sd.doc] + " score=" + sd.score);
if (first) {
System.out.println("explain: " + sBlocks.explain(query, sd.doc));
first = false;
}
}
}
}
}
// Get shard'd block grouping result:
final TopGroups<BytesRef> topGroupsBlockShards = searchShards(sBlocks, shardsBlocks.subSearchers, query,
groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, false, false);
if (expectedGroups != null) {
// Fixup scores for reader2
for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
for(ScoreDoc hit : groupDocsHits.scoreDocs) {
final GroupDoc gd = groupDocsByID[hit.doc];
assertEquals(gd.id, hit.doc);
//System.out.println("fixup score " + hit.score + " to " + gd.score2 + " vs " + gd.score);
hit.score = gd.score2;
}
}
final SortField[] sortFields = groupSort.getSort();
final Map<Float,Float> termScoreMap = scoreMap.get(searchTerm);
for(int groupSortIDX=0;groupSortIDX<sortFields.length;groupSortIDX++) {
if (sortFields[groupSortIDX].getType() == SortField.Type.SCORE) {
for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
if (groupDocsHits.groupSortValues != null) {
//System.out.println("remap " + groupDocsHits.groupSortValues[groupSortIDX] + " to " + termScoreMap.get(groupDocsHits.groupSortValues[groupSortIDX]));
groupDocsHits.groupSortValues[groupSortIDX] = termScoreMap.get(groupDocsHits.groupSortValues[groupSortIDX]);
assertNotNull(groupDocsHits.groupSortValues[groupSortIDX]);
}
}
}
}
final SortField[] docSortFields = docSort.getSort();
for(int docSortIDX=0;docSortIDX<docSortFields.length;docSortIDX++) {
if (docSortFields[docSortIDX].getType() == SortField.Type.SCORE) {
for (GroupDocs<?> groupDocsHits : expectedGroups.groups) {
for(ScoreDoc _hit : groupDocsHits.scoreDocs) {
FieldDoc hit = (FieldDoc) _hit;
if (hit.fields != null) {
hit.fields[docSortIDX] = termScoreMap.get(hit.fields[docSortIDX]);
assertNotNull(hit.fields[docSortIDX]);
}
}
}
}
}
}
assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, false);
assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, false);
}
r.close();
dir.close();
rBlocks.close();
dirBlocks.close();
}
}
private void verifyShards(int[] docStarts, TopGroups<BytesRef> topGroups) {
for(GroupDocs<?> group : topGroups.groups) {
for(int hitIDX=0;hitIDX<group.scoreDocs.length;hitIDX++) {
final ScoreDoc sd = group.scoreDocs[hitIDX];
assertEquals("doc=" + sd.doc + " wrong shard",
ReaderUtil.subIndex(sd.doc, docStarts),
sd.shardIndex);
}
}
}
private TopGroups<BytesRef> searchShards(IndexSearcher topSearcher, ShardSearcher[] subSearchers, Query query, Sort groupSort, Sort docSort, int groupOffset, int topNGroups, int docOffset,
int topNDocs, boolean getMaxScores, boolean canUseIDV, boolean preFlex) throws Exception {
// TODO: swap in caching, all groups collector hereassertEquals(expected.totalHitCount, actual.totalHitCount);
// too...
if (VERBOSE) {
System.out.println("TEST: " + subSearchers.length + " shards: " + Arrays.toString(subSearchers) + " canUseIDV=" + canUseIDV);
}
// Run 1st pass collector to get top groups per shard
final Weight w = topSearcher.createWeight(topSearcher.rewrite(query), groupSort.needsScores() || docSort.needsScores() || getMaxScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1);
final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>();
List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>();
FirstPassGroupingCollector<?> firstPassCollector = null;
boolean shardsCanUseIDV = canUseIDV;
String groupField = "group";
for(int shardIDX=0;shardIDX<subSearchers.length;shardIDX++) {
// First shard determines whether we use IDV or not;
// all other shards match that:
if (firstPassCollector == null) {
firstPassCollector = createRandomFirstPassCollector(groupField, groupSort, groupOffset + topNGroups);
} else {
firstPassCollector = createFirstPassCollector(groupField, groupSort, groupOffset + topNGroups, firstPassCollector);
}
if (VERBOSE) {
System.out.println(" shard=" + shardIDX + " groupField=" + groupField);
System.out.println(" 1st pass collector=" + firstPassCollector);
}
firstPassGroupingCollectors.add(firstPassCollector);
subSearchers[shardIDX].search(w, firstPassCollector);
final Collection<SearchGroup<BytesRef>> topGroups = getSearchGroups(firstPassCollector, 0);
if (topGroups != null) {
if (VERBOSE) {
System.out.println(" shard " + shardIDX + " s=" + subSearchers[shardIDX] + " totalGroupedHitCount=?" + " " + topGroups.size() + " groups:");
for(SearchGroup<BytesRef> group : topGroups) {
System.out.println(" " + groupToString(group.groupValue) + " groupSort=" + Arrays.toString(group.sortValues));
}
}
shardGroups.add(topGroups);
}
}
final Collection<SearchGroup<BytesRef>> mergedTopGroups = SearchGroup.merge(shardGroups, groupOffset, topNGroups, groupSort);
if (VERBOSE) {
System.out.println(" top groups merged:");
if (mergedTopGroups == null) {
System.out.println(" null");
} else {
System.out.println(" " + mergedTopGroups.size() + " top groups:");
for(SearchGroup<BytesRef> group : mergedTopGroups) {
System.out.println(" [" + groupToString(group.groupValue) + "] groupSort=" + Arrays.toString(group.sortValues));
}
}
}
if (mergedTopGroups != null) {
// Now 2nd pass:
@SuppressWarnings({"unchecked","rawtypes"})
final TopGroups<BytesRef>[] shardTopGroups = new TopGroups[subSearchers.length];
for(int shardIDX=0;shardIDX<subSearchers.length;shardIDX++) {
final TopGroupsCollector<?> secondPassCollector = createSecondPassCollector(firstPassGroupingCollectors.get(shardIDX),
groupField, mergedTopGroups, groupSort, docSort, docOffset + topNDocs, getMaxScores);
subSearchers[shardIDX].search(w, secondPassCollector);
shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0);
if (VERBOSE) {
System.out.println(" " + shardTopGroups[shardIDX].groups.length + " shard[" + shardIDX + "] groups:");
for(GroupDocs<BytesRef> group : shardTopGroups[shardIDX].groups) {
System.out.println(" [" + groupToString(group.groupValue) + "] groupSort=" + Arrays.toString(group.groupSortValues) + " numDocs=" + group.scoreDocs.length);
}
}
}
TopGroups<BytesRef> mergedGroups = TopGroups.merge(shardTopGroups, groupSort, docSort, docOffset, topNDocs, TopGroups.ScoreMergeMode.None);
if (VERBOSE) {
System.out.println(" " + mergedGroups.groups.length + " merged groups:");
for(GroupDocs<BytesRef> group : mergedGroups.groups) {
System.out.println(" [" + groupToString(group.groupValue) + "] groupSort=" + Arrays.toString(group.groupSortValues) + " numDocs=" + group.scoreDocs.length);
}
}
return mergedGroups;
} else {
return null;
}
}
private void assertEquals(int[] docIDtoID, TopGroups<BytesRef> expected, TopGroups<BytesRef> actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean idvBasedImplsUsed) {
if (expected == null) {
assertNull(actual);
return;
}
assertNotNull(actual);
assertEquals("expected.groups.length != actual.groups.length", expected.groups.length, actual.groups.length);
assertEquals("expected.totalHitCount != actual.totalHitCount", expected.totalHitCount, actual.totalHitCount);
assertEquals("expected.totalGroupedHitCount != actual.totalGroupedHitCount", expected.totalGroupedHitCount, actual.totalGroupedHitCount);
if (expected.totalGroupCount != null && verifyTotalGroupCount) {
assertEquals("expected.totalGroupCount != actual.totalGroupCount", expected.totalGroupCount, actual.totalGroupCount);
}
for(int groupIDX=0;groupIDX<expected.groups.length;groupIDX++) {
if (VERBOSE) {
System.out.println(" check groupIDX=" + groupIDX);
}
final GroupDocs<BytesRef> expectedGroup = expected.groups[groupIDX];
final GroupDocs<BytesRef> actualGroup = actual.groups[groupIDX];
if (verifyGroupValues) {
if (idvBasedImplsUsed) {
if (actualGroup.groupValue.length == 0) {
assertNull(expectedGroup.groupValue);
} else {
assertEquals(expectedGroup.groupValue, actualGroup.groupValue);
}
} else {
assertEquals(expectedGroup.groupValue, actualGroup.groupValue);
}
}
assertArrayEquals(expectedGroup.groupSortValues, actualGroup.groupSortValues);
// TODO
// assertEquals(expectedGroup.maxScore, actualGroup.maxScore);
assertEquals(expectedGroup.totalHits.value, actualGroup.totalHits.value);
final ScoreDoc[] expectedFDs = expectedGroup.scoreDocs;
final ScoreDoc[] actualFDs = actualGroup.scoreDocs;
assertEquals(expectedFDs.length, actualFDs.length);
for(int docIDX=0;docIDX<expectedFDs.length;docIDX++) {
final FieldDoc expectedFD = (FieldDoc) expectedFDs[docIDX];
final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX];
//System.out.println(" actual doc=" + docIDtoID[actualFD.doc] + " score=" + actualFD.score);
assertEquals(expectedFD.doc, docIDtoID[actualFD.doc]);
assertArrayEquals(expectedFD.fields, actualFD.fields);
}
}
}
private static class ShardSearcher extends IndexSearcher {
private final List<LeafReaderContext> ctx;
public ShardSearcher(LeafReaderContext ctx, IndexReaderContext parent) {
super(parent);
this.ctx = Collections.singletonList(ctx);
}
public void search(Weight weight, Collector collector) throws IOException {
search(ctx, weight, collector);
}
@Override
public String toString() {
return "ShardSearcher(" + ctx.get(0).reader() + ")";
}
}
}