// blob: 610ded9b3106bd4a9899d870ca35994275a20316 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
 * Base class for tests of a {@link GroupSelector} implementation.
 *
 * <p>Subclasses supply the grouping field added to each indexed document, the selector under
 * test, and a query that matches the documents of a given group value. The tests index random
 * documents and then check grouped results against plain searches filtered by group value.
 *
 * @param <T> the type of the group values produced by the selector under test
 */
public abstract class BaseGroupSelectorTestCase<T> extends AbstractGroupingTestCase {

  /** Values randomly assigned to the "text" field of each indexed document. */
  private static final String[] TEXTS = { "foo", "bar", "bar baz", "foo foo bar" };

  /** Terms from which each test randomly picks its top-level query. */
  private static final String[] QUERY_TERMS = { "foo", "bar", "baz" };

  /**
   * Adds to {@code document} the field(s) on which the selector returned by
   * {@link #getGroupSelector()} groups.
   *
   * @param document the document being built; it already carries the "id", "name", "text",
   *                 "sort1" and "sort2" fields when this is called
   * @param id       the sequential index of the document (0-based) within the test data
   */
  protected abstract void addGroupField(Document document, int id);

  /**
   * Returns the selector under test. The tests call this every time they need a selector
   * (once per collector or {@link GroupingSearch}).
   */
  protected abstract GroupSelector<T> getGroupSelector();

  /**
   * Returns a query matching the documents that belong to the group {@code groupValue}.
   * The tests use it as a FILTER clause and assert that the filtered top-level search
   * returns the same documents as the corresponding group.
   *
   * @param groupValue a group value reported by the grouping search
   */
  protected abstract Query filterQuery(T groupValue);

  /**
   * Groups by the selector with default (relevance) ordering and checks that each group's docs
   * equal a relevance search filtered by the group value, and that the first group contains the
   * overall top hit.
   */
  public void testSortByRelevance() throws IOException {
    Shard shard = new Shard();
    // Close the shard even if an assertion fails, so a failure does not also leak the index.
    try {
      indexRandomDocs(shard.writer);
      Query topLevel = randomTopLevelQuery();
      IndexSearcher searcher = shard.getIndexSearcher();

      GroupingSearch grouper = new GroupingSearch(getGroupSelector());
      grouper.setGroupDocsLimit(10);
      TopGroups<T> topGroups = grouper.search(searcher, topLevel, 0, 5);
      TopDocs topDoc = searcher.search(topLevel, 1);

      for (int i = 0; i < topGroups.groups.length; i++) {
        // Each group should have a result set equal to that returned by the top-level query,
        // filtered by the group value.
        Query filtered = filterByGroup(topLevel, topGroups.groups[i].groupValue);
        TopDocs td = searcher.search(filtered, 10);
        assertScoreDocsEquals(topGroups.groups[i].scoreDocs, td.scoreDocs);
        if (i == 0) {
          // The best hit of the first group is the best hit overall.
          assertEquals(td.scoreDocs[0].doc, topDoc.scoreDocs[0].doc);
          assertEquals(td.scoreDocs[0].score, topDoc.scoreDocs[0].score, 0);
        }
      }
    } finally {
      shard.close();
    }
  }

  /**
   * Orders the groups by an explicit {@link Sort} and checks both the order of the groups and
   * that the docs inside each group are still ordered by relevance.
   */
  public void testSortGroups() throws IOException {
    Shard shard = new Shard();
    try {
      indexRandomDocs(shard.writer);
      IndexSearcher searcher = shard.getIndexSearcher();
      Query topLevel = randomTopLevelQuery();

      GroupingSearch grouper = new GroupingSearch(getGroupSelector());
      grouper.setGroupDocsLimit(10);
      Sort sort = sortBySort1ThenSort2();
      grouper.setGroupSort(sort);
      TopGroups<T> topGroups = grouper.search(searcher, topLevel, 0, 5);
      TopDocs topDoc = searcher.search(topLevel, 1, sort);

      for (int i = 0; i < topGroups.groups.length; i++) {
        // We're sorting the groups by a defined Sort, but each group itself should be ordered
        // by doc relevance, and should be equal to the results of a top-level query filtered
        // by the group value
        Query filtered = filterByGroup(topLevel, topGroups.groups[i].groupValue);
        TopDocs td = searcher.search(filtered, 10);
        assertScoreDocsEquals(topGroups.groups[i].scoreDocs, td.scoreDocs);

        // The top group should have sort values equal to the sort values of the top doc of
        // a top-level search sorted by the same Sort; every subsequent group should have sort
        // values that compare equal to or after those of its predecessor.
        if (i > 0) {
          assertSortsBefore(topGroups.groups[i - 1], topGroups.groups[i]);
        } else {
          assertArrayEquals(((FieldDoc) topDoc.scoreDocs[0]).fields, topGroups.groups[0].groupSortValues);
        }
      }
    } finally {
      shard.close();
    }
  }

  /**
   * Orders the docs inside each group by an explicit {@link Sort} while the groups themselves
   * remain ordered by score, and checks both orderings.
   */
  public void testSortWithinGroups() throws IOException {
    Shard shard = new Shard();
    try {
      indexRandomDocs(shard.writer);
      IndexSearcher searcher = shard.getIndexSearcher();
      Query topLevel = randomTopLevelQuery();

      GroupingSearch grouper = new GroupingSearch(getGroupSelector());
      grouper.setGroupDocsLimit(10);
      Sort sort = sortBySort1ThenSort2();
      grouper.setSortWithinGroup(sort);
      TopGroups<T> topGroups = grouper.search(searcher, topLevel, 0, 5);
      TopDocs topDoc = searcher.search(topLevel, 1);

      for (int i = 0; i < topGroups.groups.length; i++) {
        // Check top-level ordering by score: first group's maxScore should be equal to the
        // top score returned by a simple search with no grouping; subsequent groups should
        // all have equal or lower maxScores
        if (i == 0) {
          assertEquals(topDoc.scoreDocs[0].score, topGroups.groups[0].maxScore, 0);
        } else {
          assertTrue(topGroups.groups[i].maxScore <= topGroups.groups[i - 1].maxScore);
        }

        // The docs within each group are ordered by the defined Sort, so each group should give
        // the same result as the top-level query, filtered by the group value, with the same Sort
        Query filtered = filterByGroup(topLevel, topGroups.groups[i].groupValue);
        TopDocs td = searcher.search(filtered, 10, sort);
        assertScoreDocsEquals(td.scoreDocs, topGroups.groups[i].scoreDocs);
      }
    } finally {
      shard.close();
    }
  }

  /**
   * Collects all groups and all group heads and checks that the groups partition the hits of
   * the top-level query and that each group head is the top relevance hit of its group.
   */
  public void testGroupHeads() throws IOException {
    Shard shard = new Shard();
    try {
      indexRandomDocs(shard.writer);
      IndexSearcher searcher = shard.getIndexSearcher();
      Query topLevel = randomTopLevelQuery();

      GroupSelector<T> groupSelector = getGroupSelector();
      GroupingSearch grouping = new GroupingSearch(groupSelector);
      grouping.setAllGroups(true);
      grouping.setAllGroupHeads(true);
      grouping.search(searcher, topLevel, 0, 1);
      Collection<T> matchingGroups = grouping.getAllMatchingGroups();

      // The number of hits from the top-level query should equal the sum of
      // the number of hits from the query filtered by each group value in turn
      int totalHits = searcher.count(topLevel);
      int groupHits = 0;
      for (T groupValue : matchingGroups) {
        groupHits += searcher.count(filterByGroup(topLevel, groupValue));
      }
      assertEquals(totalHits, groupHits);

      Bits groupHeads = grouping.getAllGroupHeads();
      // We should have one set bit per matching group
      assertEquals(matchingGroups.size(), countSetBits(groupHeads));

      // Each group head should correspond to the topdoc of a search filtered by
      // that group
      for (T groupValue : matchingGroups) {
        TopDocs td = searcher.search(filterByGroup(topLevel, groupValue), 1);
        assertTrue(groupHeads.get(td.scoreDocs[0].doc));
      }
    } finally {
      shard.close();
    }
  }

  /**
   * Same as {@link #testGroupHeads()} for the group-head checks, but with an explicit
   * within-group {@link Sort} deciding which document is the head of each group.
   */
  public void testGroupHeadsWithSort() throws IOException {
    Shard shard = new Shard();
    try {
      indexRandomDocs(shard.writer);
      IndexSearcher searcher = shard.getIndexSearcher();
      Query topLevel = randomTopLevelQuery();
      Sort sort = sortBySort1ThenSort2();

      GroupSelector<T> groupSelector = getGroupSelector();
      GroupingSearch grouping = new GroupingSearch(groupSelector);
      grouping.setAllGroups(true);
      grouping.setAllGroupHeads(true);
      grouping.setSortWithinGroup(sort);
      grouping.search(searcher, topLevel, 0, 1);
      Collection<T> matchingGroups = grouping.getAllMatchingGroups();

      Bits groupHeads = grouping.getAllGroupHeads();
      // We should have one set bit per matching group
      assertEquals(matchingGroups.size(), countSetBits(groupHeads));

      // Each group head should correspond to the topdoc of a search filtered by
      // that group using the same within-group sort
      for (T groupValue : matchingGroups) {
        TopDocs td = searcher.search(filterByGroup(topLevel, groupValue), 1, sort);
        assertTrue(groupHeads.get(td.scoreDocs[0].doc));
      }
    } finally {
      shard.close();
    }
  }

  /**
   * Indexes the same documents into a single control index and across several shards, then
   * checks that two-phase grouping merged across the shards agrees with the control.
   */
  public void testShardedGrouping() throws IOException {
    Shard control = new Shard();
    try {
      int shardCount = random().nextInt(3) + 2; // between 2 and 4 shards
      Shard[] shards = new Shard[shardCount];
      try {
        for (int i = 0; i < shardCount; i++) {
          shards[i] = new Shard();
        }
        indexIntoControlAndShards(control, shards);
        assertShardedGroupingMatchesControl(control, shards);
      } finally {
        // Entries are null if creating a later shard failed part-way through the array.
        for (Shard shard : shards) {
          if (shard != null) {
            shard.close();
          }
        }
      }
    } finally {
      control.close();
    }
  }

  /**
   * Creates a bunch of random documents and indexes each of them twice: once into the control
   * index and once into a randomly picked shard.
   */
  private void indexIntoControlAndShards(Shard control, Shard[] shards) throws IOException {
    int numDocs = atLeast(200);
    for (int i = 0; i < numDocs; i++) {
      Document doc = randomDoc(i);
      control.writer.addDocument(doc);
      int target = random().nextInt(shards.length);
      shards[target].writer.addDocument(doc);
    }
  }

  /**
   * Runs a two-phase grouped query against the control and against the shards (merging after
   * each phase) and asserts that both give the same groups and counts.
   */
  private void assertShardedGroupingMatchesControl(Shard control, Shard[] shards) throws IOException {
    Query topLevel = randomTopLevelQuery();
    Sort sort = sortBySort1ThenSort2();

    // First phase: find the top groups on the control and on each shard.
    FirstPassGroupingCollector<T> singletonFirstPass = new FirstPassGroupingCollector<>(getGroupSelector(), sort, 5);
    control.getIndexSearcher().search(topLevel, singletonFirstPass);
    Collection<SearchGroup<T>> singletonGroups = singletonFirstPass.getTopGroups(0);

    List<Collection<SearchGroup<T>>> shardGroups = new ArrayList<>();
    for (Shard shard : shards) {
      FirstPassGroupingCollector<T> fc = new FirstPassGroupingCollector<>(getGroupSelector(), sort, 5);
      shard.getIndexSearcher().search(topLevel, fc);
      shardGroups.add(fc.getTopGroups(0));
    }
    Collection<SearchGroup<T>> mergedGroups = SearchGroup.merge(shardGroups, 0, 5, sort);
    assertEquals(singletonGroups, mergedGroups);

    // Second phase: collect the docs of those groups, ordered by relevance within each group.
    TopGroupsCollector<T> singletonSecondPass = new TopGroupsCollector<>(getGroupSelector(), singletonGroups, sort,
        Sort.RELEVANCE, 5, true);
    control.getIndexSearcher().search(topLevel, singletonSecondPass);
    TopGroups<T> singletonTopGroups = singletonSecondPass.getTopGroups(0);

    // TODO why does SearchGroup.merge() take a list but TopGroups.merge() take an array?
    @SuppressWarnings("unchecked")
    TopGroups<T>[] shardTopGroups = new TopGroups[shards.length];
    for (int i = 0; i < shards.length; i++) {
      TopGroupsCollector<T> sc = new TopGroupsCollector<>(getGroupSelector(), mergedGroups, sort, Sort.RELEVANCE, 5, true);
      shards[i].getIndexSearcher().search(topLevel, sc);
      shardTopGroups[i] = sc.getTopGroups(0);
    }
    TopGroups<T> mergedTopGroups = TopGroups.merge(shardTopGroups, sort, Sort.RELEVANCE, 0, 5, TopGroups.ScoreMergeMode.None);
    assertNotNull(mergedTopGroups);

    assertEquals(singletonTopGroups.totalGroupedHitCount, mergedTopGroups.totalGroupedHitCount);
    assertEquals(singletonTopGroups.totalHitCount, mergedTopGroups.totalHitCount);
    assertEquals(singletonTopGroups.totalGroupCount, mergedTopGroups.totalGroupCount);
    assertEquals(singletonTopGroups.groups.length, mergedTopGroups.groups.length);
    for (int i = 0; i < singletonTopGroups.groups.length; i++) {
      assertEquals(singletonTopGroups.groups[i].groupValue, mergedTopGroups.groups[i].groupValue);
      assertEquals(singletonTopGroups.groups[i].scoreDocs.length, mergedTopGroups.groups[i].scoreDocs.length);
    }
  }

  /** Indexes at least 200 random documents (see {@link #randomDoc(int)}) into {@code w}. */
  private void indexRandomDocs(RandomIndexWriter w) throws IOException {
    int numDocs = atLeast(200);
    for (int i = 0; i < numDocs; i++) {
      w.addDocument(randomDoc(i));
    }
  }

  /**
   * Builds one test document: doc-values "id", stored text "name" (the id as a string), a random
   * "text" value, a random "sort1" string out of four values, a random "sort2" long, and finally
   * the subclass-specific group field.
   */
  private Document randomDoc(int id) {
    Document doc = new Document();
    doc.add(new NumericDocValuesField("id", id));
    doc.add(new TextField("name", Integer.toString(id), Field.Store.YES));
    doc.add(new TextField("text", TEXTS[random().nextInt(TEXTS.length)], Field.Store.NO));
    doc.add(new SortedDocValuesField("sort1", new BytesRef("sort" + random().nextInt(4))));
    doc.add(new NumericDocValuesField("sort2", random().nextLong()));
    addGroupField(doc, id);
    return doc;
  }

  /** Returns a term query on the "text" field for a randomly chosen query term. */
  private Query randomTopLevelQuery() {
    return new TermQuery(new Term("text", QUERY_TERMS[random().nextInt(QUERY_TERMS.length)]));
  }

  /** Returns a new sort on "sort1" (string) with "sort2" (long) as tie-breaker. */
  private static Sort sortBySort1ThenSort2() {
    return new Sort(new SortField("sort1", SortField.Type.STRING), new SortField("sort2", SortField.Type.LONG));
  }

  /**
   * Returns {@code topLevel} restricted to the documents of one group: the top-level query is a
   * MUST clause (so it still scores) and the group's {@link #filterQuery} is a FILTER clause.
   */
  private Query filterByGroup(Query topLevel, T groupValue) {
    return new BooleanQuery.Builder()
        .add(topLevel, BooleanClause.Occur.MUST)
        .add(filterQuery(groupValue), BooleanClause.Occur.FILTER)
        .build();
  }

  /** Counts the set bits in {@code bits} by probing every index below {@code bits.length()}. */
  private static int countSetBits(Bits bits) {
    int count = 0;
    for (int i = 0; i < bits.length(); i++) {
      if (bits.get(i)) {
        count++;
      }
    }
    return count;
  }

  /**
   * Asserts that the group sort values of {@code first} do not compare after those of
   * {@code second}: the first value (a {@link BytesRef}) must be less than or equal, and if the
   * first values are equal, the second value (a long) must be less than or equal.
   */
  private void assertSortsBefore(GroupDocs<T> first, GroupDocs<T> second) {
    Object[] groupSortValues = second.groupSortValues;
    Object[] prevSortValues = first.groupSortValues;
    assertTrue(((BytesRef) prevSortValues[0]).compareTo((BytesRef) groupSortValues[0]) <= 0);
    if (prevSortValues[0].equals(groupSortValues[0])) {
      // Tie on the primary key: the secondary key decides.
      assertTrue((long) prevSortValues[1] <= (long) groupSortValues[1]);
    }
  }
}