| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.grouping; |
| |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.Sort; |
| import org.apache.lucene.search.SortField; |
| import org.apache.lucene.search.TopDocs; |
| import org.apache.lucene.search.TopFieldDocs; |
| import org.apache.lucene.search.TotalHits; |
| import org.apache.lucene.search.TotalHits.Relation; |
| |
| /** Represents result returned by a grouping search. |
| * |
| * @lucene.experimental */ |
| public class TopGroups<T> { |
| /** Number of documents matching the search */ |
| public final long totalHitCount; |
| |
| /** Number of documents grouped into the topN groups */ |
| public final long totalGroupedHitCount; |
| |
| /** The total number of unique groups. If <code>null</code> this value is not computed. */ |
| public final Long totalGroupCount; |
| |
| /** Group results in groupSort order */ |
| public final GroupDocs<T>[] groups; |
| |
| /** How groups are sorted against each other */ |
| public final SortField[] groupSort; |
| |
| /** How docs are sorted within each group */ |
| public final SortField[] withinGroupSort; |
| |
| /** Highest score across all hits, or |
| * <code>Float.NaN</code> if scores were not computed. */ |
| public final float maxScore; |
| |
| public TopGroups(SortField[] groupSort, SortField[] withinGroupSort, long totalHitCount, long totalGroupedHitCount, GroupDocs<T>[] groups, float maxScore) { |
| this.groupSort = groupSort; |
| this.withinGroupSort = withinGroupSort; |
| this.totalHitCount = totalHitCount; |
| this.totalGroupedHitCount = totalGroupedHitCount; |
| this.groups = groups; |
| this.totalGroupCount = null; |
| this.maxScore = maxScore; |
| } |
| |
| public TopGroups(TopGroups<T> oldTopGroups, Long totalGroupCount) { |
| this.groupSort = oldTopGroups.groupSort; |
| this.withinGroupSort = oldTopGroups.withinGroupSort; |
| this.totalHitCount = oldTopGroups.totalHitCount; |
| this.totalGroupedHitCount = oldTopGroups.totalGroupedHitCount; |
| this.groups = oldTopGroups.groups; |
| this.maxScore = oldTopGroups.maxScore; |
| this.totalGroupCount = totalGroupCount; |
| } |
| |
| /** How the GroupDocs score (if any) should be merged. */ |
| public enum ScoreMergeMode { |
| /** Set score to Float.NaN */ |
| None, |
| /* Sum score across all shards for this group. */ |
| Total, |
| /* Avg score across all shards for this group. */ |
| Avg, |
| } |
| |
| /** |
| * If either value is NaN then return the other value, otherwise |
| * return the greater of the two values by calling Math.max. |
| * @param a - one value |
| * @param b - another value |
| * @return ignoring any NaN return the greater of a and b |
| */ |
| private static float nonNANmax(float a, float b) { |
| if (Float.isNaN(a)) return b; |
| if (Float.isNaN(b)) return a; |
| return Math.max(a, b); |
| } |
| |
| /** Merges an array of TopGroups, for example obtained |
| * from the second-pass collector across multiple |
| * shards. Each TopGroups must have been sorted by the |
| * same groupSort and docSort, and the top groups passed |
| * to all second-pass collectors must be the same. |
| * |
| * <b>NOTE</b>: We can't always compute an exact totalGroupCount. |
| * Documents belonging to a group may occur on more than |
| * one shard and thus the merged totalGroupCount can be |
| * higher than the actual totalGroupCount. In this case the |
| * totalGroupCount represents a upper bound. If the documents |
| * of one group do only reside in one shard then the |
| * totalGroupCount is exact. |
| * |
| * <b>NOTE</b>: the topDocs in each GroupDocs is actually |
| * an instance of TopDocsAndShards |
| */ |
| public static <T> TopGroups<T> merge(TopGroups<T>[] shardGroups, Sort groupSort, Sort docSort, int docOffset, int docTopN, ScoreMergeMode scoreMergeMode) { |
| |
| //System.out.println("TopGroups.merge"); |
| |
| if (shardGroups.length == 0) { |
| return null; |
| } |
| |
| long totalHitCount = 0; |
| long totalGroupedHitCount = 0; |
| // Optionally merge the totalGroupCount. |
| Long totalGroupCount = null; |
| |
| final int numGroups = shardGroups[0].groups.length; |
| for(TopGroups<T> shard : shardGroups) { |
| if (numGroups != shard.groups.length) { |
| throw new IllegalArgumentException("number of groups differs across shards; you must pass same top groups to all shards' second-pass collector"); |
| } |
| totalHitCount += shard.totalHitCount; |
| totalGroupedHitCount += shard.totalGroupedHitCount; |
| if (shard.totalGroupCount != null) { |
| if (totalGroupCount == null) { |
| totalGroupCount = 0L; |
| } |
| |
| totalGroupCount += shard.totalGroupCount; |
| } |
| } |
| |
| @SuppressWarnings({"unchecked","rawtypes"}) |
| final GroupDocs<T>[] mergedGroupDocs = new GroupDocs[numGroups]; |
| |
| final TopDocs[] shardTopDocs; |
| if (docSort.equals(Sort.RELEVANCE)) { |
| shardTopDocs = new TopDocs[shardGroups.length]; |
| } else { |
| shardTopDocs = new TopFieldDocs[shardGroups.length]; |
| } |
| float totalMaxScore = Float.NaN; |
| |
| for(int groupIDX=0;groupIDX<numGroups;groupIDX++) { |
| final T groupValue = shardGroups[0].groups[groupIDX].groupValue; |
| //System.out.println(" merge groupValue=" + groupValue + " sortValues=" + Arrays.toString(shardGroups[0].groups[groupIDX].groupSortValues)); |
| float maxScore = Float.NaN; |
| long totalHits = 0; |
| double scoreSum = 0.0; |
| for(int shardIDX=0;shardIDX<shardGroups.length;shardIDX++) { |
| //System.out.println(" shard=" + shardIDX); |
| final TopGroups<T> shard = shardGroups[shardIDX]; |
| final GroupDocs<?> shardGroupDocs = shard.groups[groupIDX]; |
| if (groupValue == null) { |
| if (shardGroupDocs.groupValue != null) { |
| throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector"); |
| } |
| } else if (!groupValue.equals(shardGroupDocs.groupValue)) { |
| throw new IllegalArgumentException("group values differ across shards; you must pass same top groups to all shards' second-pass collector"); |
| } |
| |
| /* |
| for(ScoreDoc sd : shardGroupDocs.scoreDocs) { |
| System.out.println(" doc=" + sd.doc); |
| } |
| */ |
| |
| if (docSort.equals(Sort.RELEVANCE)) { |
| shardTopDocs[shardIDX] = new TopDocs(shardGroupDocs.totalHits, |
| shardGroupDocs.scoreDocs); |
| } else { |
| shardTopDocs[shardIDX] = new TopFieldDocs(shardGroupDocs.totalHits, |
| shardGroupDocs.scoreDocs, |
| docSort.getSort()); |
| } |
| |
| for (int i = 0; i < shardTopDocs[shardIDX].scoreDocs.length; i++) { |
| shardTopDocs[shardIDX].scoreDocs[i].shardIndex = shardIDX; |
| } |
| |
| maxScore = nonNANmax(maxScore, shardGroupDocs.maxScore); |
| assert shardGroupDocs.totalHits.relation == Relation.EQUAL_TO; |
| totalHits += shardGroupDocs.totalHits.value; |
| scoreSum += shardGroupDocs.score; |
| } |
| |
| final TopDocs mergedTopDocs; |
| if (docSort.equals(Sort.RELEVANCE)) { |
| mergedTopDocs = TopDocs.merge(docOffset + docTopN, shardTopDocs); |
| } else { |
| mergedTopDocs = TopDocs.merge(docSort, docOffset + docTopN, (TopFieldDocs[]) shardTopDocs); |
| } |
| |
| // Slice; |
| final ScoreDoc[] mergedScoreDocs; |
| if (docOffset == 0) { |
| mergedScoreDocs = mergedTopDocs.scoreDocs; |
| } else if (docOffset >= mergedTopDocs.scoreDocs.length) { |
| mergedScoreDocs = new ScoreDoc[0]; |
| } else { |
| mergedScoreDocs = new ScoreDoc[mergedTopDocs.scoreDocs.length - docOffset]; |
| System.arraycopy(mergedTopDocs.scoreDocs, |
| docOffset, |
| mergedScoreDocs, |
| 0, |
| mergedTopDocs.scoreDocs.length - docOffset); |
| } |
| |
| final float groupScore; |
| switch(scoreMergeMode) { |
| case None: |
| groupScore = Float.NaN; |
| break; |
| case Avg: |
| if (totalHits > 0) { |
| groupScore = (float) (scoreSum / totalHits); |
| } else { |
| groupScore = Float.NaN; |
| } |
| break; |
| case Total: |
| groupScore = (float) scoreSum; |
| break; |
| default: |
| throw new IllegalArgumentException("can't handle ScoreMergeMode " + scoreMergeMode); |
| } |
| |
| //System.out.println("SHARDS=" + Arrays.toString(mergedTopDocs.shardIndex)); |
| mergedGroupDocs[groupIDX] = new GroupDocs<>(groupScore, |
| maxScore, |
| new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), |
| mergedScoreDocs, |
| groupValue, |
| shardGroups[0].groups[groupIDX].groupSortValues); |
| totalMaxScore = nonNANmax(totalMaxScore, maxScore); |
| } |
| |
| if (totalGroupCount != null) { |
| TopGroups<T> result = new TopGroups<>(groupSort.getSort(), |
| docSort.getSort(), |
| totalHitCount, |
| totalGroupedHitCount, |
| mergedGroupDocs, |
| totalMaxScore); |
| return new TopGroups<>(result, totalGroupCount); |
| } else { |
| return new TopGroups<>(groupSort.getSort(), |
| docSort.getSort(), |
| totalHitCount, |
| totalGroupedHitCount, |
| mergedGroupDocs, |
| totalMaxScore); |
| } |
| } |
| } |