| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.lucene.search.grouping; |
| |
| import java.io.IOException; |
| import java.util.Collection; |
| import java.util.Objects; |
| import java.util.function.Supplier; |
| |
| import org.apache.lucene.search.FilterCollector; |
| import org.apache.lucene.search.MultiCollector; |
| import org.apache.lucene.search.Scorable; |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.ScoreMode; |
| import org.apache.lucene.search.SimpleCollector; |
| import org.apache.lucene.search.Sort; |
| import org.apache.lucene.search.TopDocs; |
| import org.apache.lucene.search.TopDocsCollector; |
| import org.apache.lucene.search.TopFieldCollector; |
| import org.apache.lucene.search.TopScoreDocCollector; |
| import org.apache.lucene.util.ArrayUtil; |
| |
| /** |
| * A second-pass collector that collects the TopDocs for each group, and |
| * returns them as a {@link TopGroups} object |
| * |
| * @param <T> the type of the group value |
| */ |
| public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> { |
| |
| private final Sort groupSort; |
| private final Sort withinGroupSort; |
| private final int maxDocsPerGroup; |
| |
| /** |
| * Create a new TopGroupsCollector |
| * @param groupSelector the group selector used to define groups |
| * @param groups the groups to collect TopDocs for |
| * @param groupSort the order in which groups are returned |
| * @param withinGroupSort the order in which documents are sorted in each group |
| * @param maxDocsPerGroup the maximum number of docs to collect for each group |
| * @param getMaxScores if true, record the maximum score for each group |
| */ |
| public TopGroupsCollector(GroupSelector<T> groupSelector, Collection<SearchGroup<T>> groups, Sort groupSort, Sort withinGroupSort, |
| int maxDocsPerGroup, boolean getMaxScores) { |
| super(groupSelector, groups, |
| new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores)); |
| this.groupSort = Objects.requireNonNull(groupSort); |
| this.withinGroupSort = Objects.requireNonNull(withinGroupSort); |
| this.maxDocsPerGroup = maxDocsPerGroup; |
| |
| } |
| |
| private static class MaxScoreCollector extends SimpleCollector { |
| private Scorable scorer; |
| private float maxScore = Float.MIN_VALUE; |
| private boolean collectedAnyHits = false; |
| |
| public MaxScoreCollector() {} |
| |
| public float getMaxScore() { |
| return collectedAnyHits ? maxScore : Float.NaN; |
| } |
| |
| @Override |
| public ScoreMode scoreMode() { |
| return ScoreMode.COMPLETE; |
| } |
| |
| @Override |
| public void setScorer(Scorable scorer) { |
| this.scorer = scorer; |
| } |
| |
| @Override |
| public void collect(int doc) throws IOException { |
| collectedAnyHits = true; |
| maxScore = Math.max(scorer.score(), maxScore); |
| } |
| } |
| |
| private static class TopDocsAndMaxScoreCollector extends FilterCollector { |
| private final TopDocsCollector<?> topDocsCollector; |
| private final MaxScoreCollector maxScoreCollector; |
| private final boolean sortedByScore; |
| |
| public TopDocsAndMaxScoreCollector(boolean sortedByScore, TopDocsCollector<?> topDocsCollector, MaxScoreCollector maxScoreCollector) { |
| super(MultiCollector.wrap(topDocsCollector, maxScoreCollector)); |
| this.sortedByScore = sortedByScore; |
| this.topDocsCollector = topDocsCollector; |
| this.maxScoreCollector = maxScoreCollector; |
| } |
| } |
| |
| private static class TopDocsReducer<T> extends GroupReducer<T, TopDocsAndMaxScoreCollector> { |
| |
| private final Supplier<TopDocsAndMaxScoreCollector> supplier; |
| private final boolean needsScores; |
| |
| TopDocsReducer(Sort withinGroupSort, |
| int maxDocsPerGroup, boolean getMaxScores) { |
| this.needsScores = getMaxScores || withinGroupSort.needsScores(); |
| if (withinGroupSort == Sort.RELEVANCE) { |
| supplier = () -> new TopDocsAndMaxScoreCollector(true, TopScoreDocCollector.create(maxDocsPerGroup, Integer.MAX_VALUE), null); |
| } else { |
| supplier = () -> { |
| TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, Integer.MAX_VALUE); // TODO: disable exact counts? |
| MaxScoreCollector maxScoreCollector = getMaxScores ? new MaxScoreCollector() : null; |
| return new TopDocsAndMaxScoreCollector(false, topDocsCollector, maxScoreCollector); |
| }; |
| } |
| } |
| |
| @Override |
| public boolean needsScores() { |
| return needsScores; |
| } |
| |
| @Override |
| protected TopDocsAndMaxScoreCollector newCollector() { |
| return supplier.get(); |
| } |
| } |
| |
| /** |
| * Get the TopGroups recorded by this collector |
| * @param withinGroupOffset the offset within each group to start collecting documents |
| */ |
| public TopGroups<T> getTopGroups(int withinGroupOffset) { |
| @SuppressWarnings({"unchecked","rawtypes"}) |
| final GroupDocs<T>[] groupDocsResult = (GroupDocs<T>[]) new GroupDocs[groups.size()]; |
| |
| int groupIDX = 0; |
| float maxScore = Float.MIN_VALUE; |
| for(SearchGroup<T> group : groups) { |
| TopDocsAndMaxScoreCollector collector = (TopDocsAndMaxScoreCollector) groupReducer.getCollector(group.groupValue); |
| final TopDocs topDocs; |
| final float groupMaxScore; |
| if (collector.sortedByScore) { |
| TopDocs allTopDocs = collector.topDocsCollector.topDocs(); |
| groupMaxScore = allTopDocs.scoreDocs.length == 0 ? Float.NaN : allTopDocs.scoreDocs[0].score; |
| if (allTopDocs.scoreDocs.length <= withinGroupOffset) { |
| topDocs = new TopDocs(allTopDocs.totalHits, new ScoreDoc[0]); |
| } else { |
| topDocs = new TopDocs(allTopDocs.totalHits, ArrayUtil.copyOfSubArray(allTopDocs.scoreDocs, withinGroupOffset, Math.min(allTopDocs.scoreDocs.length, withinGroupOffset + maxDocsPerGroup))); |
| } |
| } else { |
| topDocs = collector.topDocsCollector.topDocs(withinGroupOffset, maxDocsPerGroup); |
| if (collector.maxScoreCollector == null) { |
| groupMaxScore = Float.NaN; |
| } else { |
| groupMaxScore = collector.maxScoreCollector.getMaxScore(); |
| } |
| } |
| |
| groupDocsResult[groupIDX++] = new GroupDocs<>(Float.NaN, |
| groupMaxScore, |
| topDocs.totalHits, |
| topDocs.scoreDocs, |
| group.groupValue, |
| group.sortValues); |
| maxScore = Math.max(maxScore, groupMaxScore); |
| } |
| |
| return new TopGroups<>(groupSort.getSort(), |
| withinGroupSort.getSort(), |
| totalHitCount, totalGroupedHitCount, groupDocsResult, |
| maxScore); |
| } |
| |
| |
| } |