blob: 2db765abfee9570bfbf420bf2bf484d1a40ff712 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.BytesRefFieldSource;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
public class AllGroupHeadsCollectorTest extends LuceneTestCase {
/**
 * Indexes eight small documents (six spread over three authors, two without an author) across
 * at least two segments and checks the group heads reported for several term queries and
 * within-group sorts.
 */
public void testBasic() throws Exception {
  final String groupField = "author";
  Directory dir = newDirectory();
  RandomIndexWriter w = new RandomIndexWriter(
      random(),
      dir,
      newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy()));
  DocValuesType valueType = DocValuesType.SORTED;

  // docs 0-2: author1
  addBasicDoc(w, groupField, "author1", valueType, "random text", 1, "1");
  addBasicDoc(w, groupField, "author1", valueType, "some more random text blob", 2, "2");
  addBasicDoc(w, groupField, "author1", valueType, "some more random textual data", 3, "3");
  w.commit(); // To ensure a second segment

  // doc 3: author2
  addBasicDoc(w, groupField, "author2", valueType, "some random text", 4, "4");
  // docs 4-5: author3
  addBasicDoc(w, groupField, "author3", valueType, "some more random text", 5, "5");
  addBasicDoc(w, groupField, "author3", valueType, "random blob", 6, "6");
  // docs 6-7: no author field
  addBasicDoc(w, groupField, null, valueType, "random word stuck in alot of other text", 6, "6");
  addBasicDoc(w, groupField, null, valueType, "random word stuck in alot of other text", 7, "7");

  IndexReader reader = w.getReader();
  IndexSearcher indexSearcher = newSearcher(reader);
  w.close();
  int maxDoc = reader.maxDoc();

  // Numeric sort, descending: the highest id_1 in each group is the head.
  Sort sortWithinGroup = new Sort(new SortField("id_1", SortField.Type.INT, true));
  assertBasicGroupHeads(indexSearcher, groupField, sortWithinGroup, "random", maxDoc, 2, 3, 5, 7);
  assertBasicGroupHeads(indexSearcher, groupField, sortWithinGroup, "some", maxDoc, 2, 3, 4);
  assertBasicGroupHeads(indexSearcher, groupField, sortWithinGroup, "blob", maxDoc, 1, 5);

  // STRING sort type triggers different implementation
  Sort sortWithinGroup2 = new Sort(new SortField("id_2", SortField.Type.STRING, true));
  assertBasicGroupHeads(indexSearcher, groupField, sortWithinGroup2, "random", maxDoc, 2, 3, 5, 7);

  // Ascending id_2: the lowest value in each group is the head, so the two author-less docs
  // are represented by doc 6 rather than doc 7.
  Sort sortWithinGroup3 = new Sort(new SortField("id_2", SortField.Type.STRING, false));
  assertBasicGroupHeads(indexSearcher, groupField, sortWithinGroup3, "random", maxDoc, 0, 3, 4, 6);

  indexSearcher.getIndexReader().close();
  dir.close();
}

/**
 * Adds one document for {@link #testBasic()}: an optional group field, the analyzed "content"
 * field and the two sort fields "id_1" (numeric) and "id_2" (sorted bytes).
 *
 * @param author group value, or {@code null} to leave the group field out of the document
 */
private void addBasicDoc(RandomIndexWriter w, String groupField, String author, DocValuesType valueType,
                         String content, long id1, String id2) throws IOException {
  Document doc = new Document();
  if (author != null) {
    addGroupField(doc, groupField, author, valueType);
  }
  doc.add(newTextField("content", content, Field.Store.NO));
  doc.add(new NumericDocValuesField("id_1", id1));
  doc.add(new SortedDocValuesField("id_2", new BytesRef(id2)));
  w.addDocument(doc);
}

/**
 * Runs a term query on "content" through a freshly created random collector and asserts that
 * both the int[] and the bit-set view of the group heads equal {@code expectedHeads}.
 */
private void assertBasicGroupHeads(IndexSearcher searcher, String groupField, Sort sortWithinGroup,
                                   String term, int maxDoc, int... expectedHeads) throws IOException {
  AllGroupHeadsCollector<?> collector = createRandomCollector(groupField, sortWithinGroup);
  searcher.search(new TermQuery(new Term("content", term)), collector);
  assertTrue(arrayContains(expectedHeads, collector.retrieveGroupHeads()));
  assertTrue(openBitSetContains(expectedHeads, collector.retrieveGroupHeads(maxDoc), maxDoc));
}
/**
 * Randomized check of {@link AllGroupHeadsCollector}. Each iteration indexes a random set of
 * documents, then runs 100 searches with a random term and a random within-group sort and
 * compares the collector's group heads with the heads computed by createExpectedGroupHeads.
 */
public void testRandom() throws Exception {
int numberOfRuns = atLeast(1);
for (int iter = 0; iter < numberOfRuns; iter++) {
if (VERBOSE) {
System.out.println(String.format(Locale.ROOT, "TEST: iter=%d total=%d", iter, numberOfRuns));
}
final int numDocs = TestUtil.nextInt(random(), 100, 1000) * RANDOM_MULTIPLIER;
final int numGroups = TestUtil.nextInt(random(), 1, numDocs);
if (VERBOSE) {
System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups);
}
// Pool of random, non-empty values; used both as group values and as sort1/sort2 values below.
final List<BytesRef> groups = new ArrayList<>();
for (int i = 0; i < numGroups; i++) {
String randomValue;
do {
// B/c of DV based impl we can't see the difference between an empty string and a null value.
// For that reason we don't generate empty string groups.
randomValue = TestUtil.randomRealisticUnicodeString(random());
//randomValue = TestUtil.randomSimpleString(random());
} while ("".equals(randomValue));
groups.add(new BytesRef(randomValue));
}
final String[] contentStrings = new String[TestUtil.nextInt(random(), 2, 20)];
if (VERBOSE) {
System.out.println("TEST: create fake content");
}
// Every content string is "real0", "real1" or "real2" followed by 0-9 "fake" tokens.
for (int contentIDX = 0; contentIDX < contentStrings.length; contentIDX++) {
final StringBuilder sb = new StringBuilder();
sb.append("real").append(random().nextInt(3)).append(' ');
final int fakeCount = random().nextInt(10);
for (int fakeIDX = 0; fakeIDX < fakeCount; fakeIDX++) {
sb.append("fake ");
}
contentStrings[contentIDX] = sb.toString();
if (VERBOSE) {
System.out.println(" content=" + sb.toString());
}
}
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(new MockAnalyzer(random())));
// NOTE(review): valueType is not read anywhere else in this method.
DocValuesType valueType = DocValuesType.SORTED;
// Two reusable documents sharing the same Field instances: "doc" carries the group field,
// "docNoGroup" does not. Field values are overwritten for every indexed document.
Document doc = new Document();
Document docNoGroup = new Document();
Field valuesField = null;
valuesField = new SortedDocValuesField("group", new BytesRef());
doc.add(valuesField);
Field sort1 = new SortedDocValuesField("sort1", new BytesRef());
doc.add(sort1);
docNoGroup.add(sort1);
Field sort2 = new SortedDocValuesField("sort2", new BytesRef());
doc.add(sort2);
docNoGroup.add(sort2);
Field sort3 = new SortedDocValuesField("sort3", new BytesRef());
doc.add(sort3);
docNoGroup.add(sort3);
Field content = newTextField("content", "", Field.Store.NO);
doc.add(content);
docNoGroup.add(content);
NumericDocValuesField idDV = new NumericDocValuesField("id", 0);
doc.add(idDV);
docNoGroup.add(idDV);
// groupDocs[i] is the expected-side model of the document whose "id" field is i.
final GroupDoc[] groupDocs = new GroupDoc[numDocs];
for (int i = 0; i < numDocs; i++) {
final BytesRef groupValue;
if (random().nextInt(24) == 17) {
// So we test the "doc doesn't have the group'd
// field" case:
groupValue = null;
} else {
groupValue = groups.get(random().nextInt(groups.size()));
}
// sort1 and sort2 are drawn from the shared value pool; sort3 is the zero-padded id, so it
// differs for every document.
final GroupDoc groupDoc = new GroupDoc(
i,
groupValue,
groups.get(random().nextInt(groups.size())),
groups.get(random().nextInt(groups.size())),
new BytesRef(String.format(Locale.ROOT, "%05d", i)),
contentStrings[random().nextInt(contentStrings.length)]
);
if (VERBOSE) {
System.out.println(" doc content=" + groupDoc.content + " id=" + i + " group=" + (groupDoc.group == null ? "null" : groupDoc.group.utf8ToString()) + " sort1=" + groupDoc.sort1.utf8ToString() + " sort2=" + groupDoc.sort2.utf8ToString() + " sort3=" + groupDoc.sort3.utf8ToString());
}
groupDocs[i] = groupDoc;
if (groupDoc.group != null) {
valuesField.setBytesValue(new BytesRef(groupDoc.group.utf8ToString()));
}
sort1.setBytesValue(groupDoc.sort1);
sort2.setBytesValue(groupDoc.sort2);
sort3.setBytesValue(groupDoc.sort3);
content.setStringValue(groupDoc.content);
idDV.setLongValue(groupDoc.id);
if (groupDoc.group == null) {
w.addDocument(docNoGroup);
} else {
w.addDocument(doc);
}
}
final DirectoryReader r = w.getReader();
w.close();
// Build both directions of the mapping between Lucene doc ids and our "id" field values.
// The assertEquals also verifies that every Lucene doc id in [0, numDocs) has an "id" value.
NumericDocValues values = MultiDocValues.getNumericValues(r, "id");
final int[] docIDToFieldId = new int[numDocs];
final int[] fieldIdToDocID = new int[numDocs];
for (int i = 0; i < numDocs; i++) {
assertEquals(i, values.nextDoc());
int fieldId = (int) values.longValue();
docIDToFieldId[i] = fieldId;
fieldIdToDocID[fieldId] = i;
}
final IndexSearcher s = newSearcher(r);
// Record the score each document gets for the one "realN" term it contains; the expected-side
// comparator needs it for score sorts. A document must not be hit by more than one query.
Set<Integer> seenIDs = new HashSet<>();
for (int contentID = 0; contentID < 3; contentID++) {
final ScoreDoc[] hits = s.search(new TermQuery(new Term("content", "real" + contentID)), numDocs).scoreDocs;
for (ScoreDoc hit : hits) {
int idValue = docIDToFieldId[hit.doc];
final GroupDoc gd = groupDocs[idValue];
assertEquals(gd.id, idValue);
seenIDs.add(idValue);
assertTrue(gd.score == 0.0);
gd.score = hit.score;
}
}
// make sure all docs were seen across the hits
assertEquals(groupDocs.length, seenIDs.size());
// make sure scores are sane
for (GroupDoc gd : groupDocs) {
assertTrue(Float.isFinite(gd.score));
assertTrue(gd.score >= 0.0);
}
for (int searchIter = 0; searchIter < 100; searchIter++) {
if (VERBOSE) {
System.out.println("TEST: searchIter=" + searchIter);
}
final String searchTerm = "real" + random().nextInt(3);
boolean sortByScoreOnly = random().nextBoolean();
Sort sortWithinGroup = getRandomSort(sortByScoreOnly);
AllGroupHeadsCollector<?> allGroupHeadsCollector = createRandomCollector("group", sortWithinGroup);
s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollector);
int[] expectedGroupHeads = createExpectedGroupHeads(searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID);
int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads();
// The actual group heads contains Lucene ids. Need to change them into our id value.
for (int i = 0; i < actualGroupHeads.length; i++) {
actualGroupHeads[i] = docIDToFieldId[actualGroupHeads[i]];
}
// Allows us to easily iterate and assert the actual and expected results.
Arrays.sort(expectedGroupHeads);
Arrays.sort(actualGroupHeads);
if (VERBOSE) {
System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName());
System.out.println("Sort within group: " + sortWithinGroup);
System.out.println("Num group: " + numGroups);
System.out.println("Num doc: " + numDocs);
System.out.println("\n=== Expected: \n");
for (int expectedDocId : expectedGroupHeads) {
GroupDoc expectedGroupDoc = groupDocs[expectedDocId];
String expectedGroup = expectedGroupDoc.group == null ? null : expectedGroupDoc.group.utf8ToString();
System.out.println(
String.format(Locale.ROOT,
"Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d",
expectedGroup, expectedGroupDoc.score, expectedGroupDoc.sort1.utf8ToString(),
expectedGroupDoc.sort2.utf8ToString(), expectedGroupDoc.sort3.utf8ToString(), expectedDocId
)
);
}
System.out.println("\n=== Actual: \n");
for (int actualDocId : actualGroupHeads) {
GroupDoc actualGroupDoc = groupDocs[actualDocId];
String actualGroup = actualGroupDoc.group == null ? null : actualGroupDoc.group.utf8ToString();
System.out.println(
String.format(Locale.ROOT,
"Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d",
actualGroup, actualGroupDoc.score, actualGroupDoc.sort1.utf8ToString(),
actualGroupDoc.sort2.utf8ToString(), actualGroupDoc.sort3.utf8ToString(), actualDocId
)
);
}
System.out.println("\n===================================================================================");
}
// Both arrays hold "id" field values and were sorted above, so they must match element-wise.
assertArrayEquals(expectedGroupHeads, actualGroupHeads);
}
r.close();
dir.close();
}
}
/**
 * Tells whether {@code actual} holds exactly the values of {@code expected}, ignoring order.
 *
 * <p>Both arrays are compared as sorted copies, so neither argument is modified and repeated
 * values have to occur the same number of times on both sides.
 *
 * @param expected the doc ids that should be present
 * @param actual the doc ids that were reported; not necessarily sorted by doc id
 * @return {@code true} if both arrays contain the same values with the same multiplicity
 */
private boolean arrayContains(int[] expected, int[] actual) {
  if (expected.length != actual.length) {
    return false;
  }
  // Sort copies rather than the caller's arrays: in some cases the actual docs aren't sorted by docid.
  final int[] sortedExpected = expected.clone();
  final int[] sortedActual = actual.clone();
  Arrays.sort(sortedExpected);
  Arrays.sort(sortedActual);
  return Arrays.equals(sortedExpected, sortedActual);
}
/**
 * Tells whether the bit set {@code actual} has exactly the bits of {@code expectedDocs} set.
 *
 * @param expectedDocs doc ids that must be set
 * @param actual the bits to check; must be a {@link FixedBitSet}
 * @param maxDoc size used for the temporary bit set built from {@code expectedDocs}
 * @return {@code true} if the cardinality matches and every expected doc id is set
 */
private boolean openBitSetContains(int[] expectedDocs, Bits actual, int maxDoc) throws IOException {
  assert actual instanceof FixedBitSet;
  final FixedBitSet actualBits = (FixedBitSet) actual;
  if (actualBits.cardinality() != expectedDocs.length) {
    return false;
  }

  final FixedBitSet expectedBits = new FixedBitSet(maxDoc);
  for (int expectedDoc : expectedDocs) {
    expectedBits.set(expectedDoc);
  }

  // Walk the set bits of the expected set; stop before asking for a bit past its end.
  int current = expectedBits.nextSetBit(0);
  while (current != DocIdSetIterator.NO_MORE_DOCS) {
    if (!actual.get(current)) {
      return false;
    }
    final int following = current + 1;
    if (following >= expectedBits.length()) {
      current = DocIdSetIterator.NO_MORE_DOCS;
    } else {
      current = expectedBits.nextSetBit(following);
    }
  }
  return true;
}
/**
 * Computes the expected group heads without going through the index.
 *
 * <p>Documents whose content starts with {@code searchTerm} are bucketed by group value
 * (a {@code null} group forms its own bucket); each bucket is sorted with the comparator
 * from getComparator and its first document is taken as head.
 *
 * @return the {@code id} values of the head documents, in map iteration order
 */
private int[] createExpectedGroupHeads(String searchTerm, GroupDoc[] groupDocs, Sort docSort, boolean sortByScoreOnly, int[] fieldIdToDocID) {
  final Map<BytesRef, List<GroupDoc>> docsByGroup = new HashMap<>();
  for (GroupDoc candidate : groupDocs) {
    if (candidate.content.startsWith(searchTerm)) {
      docsByGroup.computeIfAbsent(candidate.group, key -> new ArrayList<>()).add(candidate);
    }
  }

  final int[] heads = new int[docsByGroup.size()];
  int next = 0;
  for (List<GroupDoc> members : docsByGroup.values()) {
    Collections.sort(members, getComparator(docSort, sortByScoreOnly, fieldIdToDocID));
    heads[next] = members.get(0).id;
    next++;
  }
  return heads;
}
/**
 * Builds a random within-group sort.
 *
 * <p>The primary part is either the score, one of sort1/sort2, both of them, or nothing. Unless
 * {@code scoreOnly} is set, a tie-breaking field (sort3 or id) is appended. The random draws
 * are made in a fixed order regardless of {@code scoreOnly}.
 *
 * @param scoreOnly if {@code true} the result sorts by score and nothing else
 */
private Sort getRandomSort(boolean scoreOnly) {
  final List<SortField> fields = new ArrayList<>();

  final boolean rolledScore = random().nextInt(7) == 2;
  if (rolledScore || scoreOnly) {
    fields.add(SortField.FIELD_SCORE);
  } else if (random().nextBoolean()) {
    // A single string field, picked at random.
    final String field = random().nextBoolean() ? "sort1" : "sort2";
    fields.add(new SortField(field, SortField.Type.STRING, random().nextBoolean()));
  } else if (random().nextBoolean()) {
    // Both string fields.
    fields.add(new SortField("sort1", SortField.Type.STRING, random().nextBoolean()));
    fields.add(new SortField("sort2", SortField.Type.STRING, random().nextBoolean()));
  }

  // Break ties; the draw happens even when scoreOnly suppresses the tie-breaker.
  final boolean tieBreakOnSort3 = random().nextBoolean();
  if (!scoreOnly) {
    if (tieBreakOnSort3) {
      fields.add(new SortField("sort3", SortField.Type.STRING));
    } else {
      fields.add(new SortField("id", SortField.Type.INT));
    }
  }
  return new Sort(fields.toArray(new SortField[0]));
}
/**
 * Builds a comparator over {@link GroupDoc}s that follows the given sort field by field.
 *
 * <p>Score fields put the higher score first; when {@code sortByScoreOnly} is set, equal scores
 * are ordered by the position looked up in {@code fieldIdToDocID}. The fields "sort1", "sort2"
 * and "sort3" compare the corresponding bytes; any other field must be "id". A field's reverse
 * flag flips its result. The comparator fails the test if no sort field separates two documents.
 *
 * @param sort the within-group sort to mirror
 * @param sortByScoreOnly whether score ties are broken via {@code fieldIdToDocID}
 * @param fieldIdToDocID lookup table indexed by {@link GroupDoc#id}
 */
private Comparator<GroupDoc> getComparator(Sort sort, final boolean sortByScoreOnly, final int[] fieldIdToDocID) {
  final SortField[] sortFields = sort.getSort();
  return new Comparator<GroupDoc>() {
    @Override
    public int compare(GroupDoc d1, GroupDoc d2) {
      for (SortField sf : sortFields) {
        final int cmp;
        if (sf.getType() == SortField.Type.SCORE) {
          if (d1.score > d2.score) {
            cmp = -1;
          } else if (d1.score < d2.score) {
            cmp = 1;
          } else if (sortByScoreOnly) {
            // Integer.compare instead of subtraction: no overflow, result is always -1/0/1.
            cmp = Integer.compare(fieldIdToDocID[d1.id], fieldIdToDocID[d2.id]);
          } else {
            cmp = 0;
          }
        } else {
          final String field = sf.getField();
          if ("sort1".equals(field)) {
            cmp = d1.sort1.compareTo(d2.sort1);
          } else if ("sort2".equals(field)) {
            cmp = d1.sort2.compareTo(d2.sort2);
          } else if ("sort3".equals(field)) {
            cmp = d1.sort3.compareTo(d2.sort3);
          } else {
            // Expected value first, so a failure message reads the right way round.
            assertEquals("id", field);
            cmp = Integer.compare(d1.id, d2.id);
          }
        }
        if (cmp != 0) {
          return sf.getReverse() ? -cmp : cmp;
        }
      }
      // Our sort always fully tie breaks:
      fail();
      return 0;
    }
  };
}
/**
 * Creates an {@link AllGroupHeadsCollector} for the given group field, randomly backed by either
 * a {@code TermGroupSelector} or a {@code ValueSourceGroupSelector}.
 */
@SuppressWarnings({"unchecked","rawtypes"})
private AllGroupHeadsCollector<?> createRandomCollector(String groupField, Sort sortWithinGroup) {
  final boolean useValueSource = random().nextBoolean();
  if (!useValueSource) {
    return AllGroupHeadsCollector.newCollector(new TermGroupSelector(groupField), sortWithinGroup);
  }
  ValueSource groupValues = new BytesRefFieldSource(groupField);
  return AllGroupHeadsCollector.newCollector(new ValueSourceGroupSelector(groupValues, new HashMap<>()), sortWithinGroup);
}
/**
 * Adds the group value to {@code doc} as a doc-values field of the requested type.
 * Only BINARY and SORTED are handled; any other type fails the test.
 */
private void addGroupField(Document doc, String groupField, String value, DocValuesType valueType) {
  switch (valueType) {
    case BINARY:
      doc.add(new BinaryDocValuesField(groupField, new BytesRef(value)));
      return;
    case SORTED:
      doc.add(new SortedDocValuesField(groupField, new BytesRef(value)));
      return;
    default:
      fail("unhandled type");
  }
}
/** In-memory model of one indexed document, used to compute the expected group heads. */
private static class GroupDoc {
  /** Value stored in the document's "id" field. */
  final int id;
  /** Group value; {@code null} when the document has no group field. */
  final BytesRef group;
  final BytesRef sort1;
  final BytesRef sort2;
  final BytesRef sort3;
  // content must be "realN ..."
  final String content;
  /** Score of the document's hit; stays 0 until the test records it. */
  float score = 0.0f;

  public GroupDoc(int id, BytesRef group, BytesRef sort1, BytesRef sort2, BytesRef sort3, String content) {
    this.id = id;
    this.content = content;
    this.group = group;
    this.sort3 = sort3;
    this.sort2 = sort2;
    this.sort1 = sort1;
  }
}
}