blob: 01e9928222560192e55be2cdf400948a1f8df03c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;
import java.io.IOException;
import java.util.Collection;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.util.ArrayUtil;
/**
* A second-pass collector that collects the TopDocs for each group, and
* returns them as a {@link TopGroups} object
*
* @param <T> the type of the group value
*/
public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> {
private final Sort groupSort;
private final Sort withinGroupSort;
private final int maxDocsPerGroup;
/**
* Create a new TopGroupsCollector
* @param groupSelector the group selector used to define groups
* @param groups the groups to collect TopDocs for
* @param groupSort the order in which groups are returned
* @param withinGroupSort the order in which documents are sorted in each group
* @param maxDocsPerGroup the maximum number of docs to collect for each group
* @param getMaxScores if true, record the maximum score for each group
*/
public TopGroupsCollector(GroupSelector<T> groupSelector, Collection<SearchGroup<T>> groups, Sort groupSort, Sort withinGroupSort,
int maxDocsPerGroup, boolean getMaxScores) {
super(groupSelector, groups,
new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores));
this.groupSort = Objects.requireNonNull(groupSort);
this.withinGroupSort = Objects.requireNonNull(withinGroupSort);
this.maxDocsPerGroup = maxDocsPerGroup;
}
private static class MaxScoreCollector extends SimpleCollector {
private Scorable scorer;
private float maxScore = Float.MIN_VALUE;
private boolean collectedAnyHits = false;
public MaxScoreCollector() {}
public float getMaxScore() {
return collectedAnyHits ? maxScore : Float.NaN;
}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE;
}
@Override
public void setScorer(Scorable scorer) {
this.scorer = scorer;
}
@Override
public void collect(int doc) throws IOException {
collectedAnyHits = true;
maxScore = Math.max(scorer.score(), maxScore);
}
}
private static class TopDocsAndMaxScoreCollector extends FilterCollector {
private final TopDocsCollector<?> topDocsCollector;
private final MaxScoreCollector maxScoreCollector;
private final boolean sortedByScore;
public TopDocsAndMaxScoreCollector(boolean sortedByScore, TopDocsCollector<?> topDocsCollector, MaxScoreCollector maxScoreCollector) {
super(MultiCollector.wrap(topDocsCollector, maxScoreCollector));
this.sortedByScore = sortedByScore;
this.topDocsCollector = topDocsCollector;
this.maxScoreCollector = maxScoreCollector;
}
}
private static class TopDocsReducer<T> extends GroupReducer<T, TopDocsAndMaxScoreCollector> {
private final Supplier<TopDocsAndMaxScoreCollector> supplier;
private final boolean needsScores;
TopDocsReducer(Sort withinGroupSort,
int maxDocsPerGroup, boolean getMaxScores) {
this.needsScores = getMaxScores || withinGroupSort.needsScores();
if (withinGroupSort == Sort.RELEVANCE) {
supplier = () -> new TopDocsAndMaxScoreCollector(true, TopScoreDocCollector.create(maxDocsPerGroup, Integer.MAX_VALUE), null);
} else {
supplier = () -> {
TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, Integer.MAX_VALUE); // TODO: disable exact counts?
MaxScoreCollector maxScoreCollector = getMaxScores ? new MaxScoreCollector() : null;
return new TopDocsAndMaxScoreCollector(false, topDocsCollector, maxScoreCollector);
};
}
}
@Override
public boolean needsScores() {
return needsScores;
}
@Override
protected TopDocsAndMaxScoreCollector newCollector() {
return supplier.get();
}
}
/**
* Get the TopGroups recorded by this collector
* @param withinGroupOffset the offset within each group to start collecting documents
*/
public TopGroups<T> getTopGroups(int withinGroupOffset) {
@SuppressWarnings({"unchecked","rawtypes"})
final GroupDocs<T>[] groupDocsResult = (GroupDocs<T>[]) new GroupDocs[groups.size()];
int groupIDX = 0;
float maxScore = Float.MIN_VALUE;
for(SearchGroup<T> group : groups) {
TopDocsAndMaxScoreCollector collector = (TopDocsAndMaxScoreCollector) groupReducer.getCollector(group.groupValue);
final TopDocs topDocs;
final float groupMaxScore;
if (collector.sortedByScore) {
TopDocs allTopDocs = collector.topDocsCollector.topDocs();
groupMaxScore = allTopDocs.scoreDocs.length == 0 ? Float.NaN : allTopDocs.scoreDocs[0].score;
if (allTopDocs.scoreDocs.length <= withinGroupOffset) {
topDocs = new TopDocs(allTopDocs.totalHits, new ScoreDoc[0]);
} else {
topDocs = new TopDocs(allTopDocs.totalHits, ArrayUtil.copyOfSubArray(allTopDocs.scoreDocs, withinGroupOffset, Math.min(allTopDocs.scoreDocs.length, withinGroupOffset + maxDocsPerGroup)));
}
} else {
topDocs = collector.topDocsCollector.topDocs(withinGroupOffset, maxDocsPerGroup);
if (collector.maxScoreCollector == null) {
groupMaxScore = Float.NaN;
} else {
groupMaxScore = collector.maxScoreCollector.getMaxScore();
}
}
groupDocsResult[groupIDX++] = new GroupDocs<>(Float.NaN,
groupMaxScore,
topDocs.totalHits,
topDocs.scoreDocs,
group.groupValue,
group.sortValues);
maxScore = Math.max(maxScore, groupMaxScore);
}
return new TopGroups<>(groupSort.getSort(),
withinGroupSort.getSort(),
totalHitCount, totalGroupedHitCount, groupDocsResult,
maxScore);
}
}