| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.grouping; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.Comparator; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.NavigableSet; |
| import java.util.TreeSet; |
| |
| import org.apache.lucene.search.FieldComparator; |
| import org.apache.lucene.search.Sort; |
| import org.apache.lucene.search.SortField; |
| |
| /** |
| * Represents a group that is found during the first pass search. |
| * |
| * @lucene.experimental |
| */ |
| public class SearchGroup<T> { |
| |
| /** The value that defines this group */ |
| public T groupValue; |
| |
| /** The sort values used during sorting. These are the |
| * groupSort field values of the highest rank document |
| * (by the groupSort) within the group. Can be |
| * <code>null</code> if <code>fillFields=false</code> had |
| * been passed to {@link FirstPassGroupingCollector#getTopGroups} */ |
| public Object[] sortValues; |
| |
| @Override |
| public String toString() { |
| return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")"); |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| if (this == o) return true; |
| if (o == null || getClass() != o.getClass()) return false; |
| |
| SearchGroup<?> that = (SearchGroup<?>) o; |
| |
| if (groupValue == null) { |
| if (that.groupValue != null) { |
| return false; |
| } |
| } else if (!groupValue.equals(that.groupValue)) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| @Override |
| public int hashCode() { |
| return groupValue != null ? groupValue.hashCode() : 0; |
| } |
| |
| private static class ShardIter<T> { |
| public final Iterator<SearchGroup<T>> iter; |
| public final int shardIndex; |
| |
| public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) { |
| this.shardIndex = shardIndex; |
| iter = shard.iterator(); |
| assert iter.hasNext(); |
| } |
| |
| public SearchGroup<T> next() { |
| assert iter.hasNext(); |
| final SearchGroup<T> group = iter.next(); |
| if (group.sortValues == null) { |
| throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector"); |
| } |
| return group; |
| } |
| |
| @Override |
| public String toString() { |
| return "ShardIter(shard=" + shardIndex + ")"; |
| } |
| } |
| |
| // Holds all shards currently on the same group |
| private static class MergedGroup<T> { |
| |
| // groupValue may be null! |
| public final T groupValue; |
| |
| public Object[] topValues; |
| public final List<ShardIter<T>> shards = new ArrayList<>(); |
| public int minShardIndex; |
| public boolean processed; |
| public boolean inQueue; |
| |
| public MergedGroup(T groupValue) { |
| this.groupValue = groupValue; |
| } |
| |
| // Only for assert |
| private boolean neverEquals(Object _other) { |
| if (_other instanceof MergedGroup) { |
| MergedGroup<?> other = (MergedGroup<?>) _other; |
| if (groupValue == null) { |
| assert other.groupValue != null; |
| } else { |
| assert !groupValue.equals(other.groupValue); |
| } |
| } |
| return true; |
| } |
| |
| @Override |
| public boolean equals(Object _other) { |
| // We never have another MergedGroup instance with |
| // same groupValue |
| assert neverEquals(_other); |
| |
| if (_other instanceof MergedGroup) { |
| MergedGroup<?> other = (MergedGroup<?>) _other; |
| if (groupValue == null) { |
| return other == null; |
| } else { |
| return groupValue.equals(other); |
| } |
| } else { |
| return false; |
| } |
| } |
| |
| @Override |
| public int hashCode() { |
| if (groupValue == null) { |
| return 0; |
| } else { |
| return groupValue.hashCode(); |
| } |
| } |
| } |
| |
| private static class GroupComparator<T> implements Comparator<MergedGroup<T>> { |
| |
| @SuppressWarnings("rawtypes") |
| public final FieldComparator[] comparators; |
| |
| public final int[] reversed; |
| |
| @SuppressWarnings({"unchecked", "rawtypes"}) |
| public GroupComparator(Sort groupSort) { |
| final SortField[] sortFields = groupSort.getSort(); |
| comparators = new FieldComparator[sortFields.length]; |
| reversed = new int[sortFields.length]; |
| for (int compIDX = 0; compIDX < sortFields.length; compIDX++) { |
| final SortField sortField = sortFields[compIDX]; |
| comparators[compIDX] = sortField.getComparator(1, compIDX); |
| reversed[compIDX] = sortField.getReverse() ? -1 : 1; |
| } |
| } |
| |
| @Override |
| @SuppressWarnings({"unchecked","rawtypes"}) |
| public int compare(MergedGroup<T> group, MergedGroup<T> other) { |
| if (group == other) { |
| return 0; |
| } |
| //System.out.println("compare group=" + group + " other=" + other); |
| final Object[] groupValues = group.topValues; |
| final Object[] otherValues = other.topValues; |
| //System.out.println(" groupValues=" + groupValues + " otherValues=" + otherValues); |
| for (int compIDX = 0;compIDX < comparators.length; compIDX++) { |
| final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX], |
| otherValues[compIDX]); |
| if (c != 0) { |
| return c; |
| } |
| } |
| |
| // Tie break by min shard index: |
| assert group.minShardIndex != other.minShardIndex; |
| return group.minShardIndex - other.minShardIndex; |
| } |
| } |
| |
| private static class GroupMerger<T> { |
| |
| private final GroupComparator<T> groupComp; |
| private final NavigableSet<MergedGroup<T>> queue; |
| private final Map<T,MergedGroup<T>> groupsSeen; |
| |
| public GroupMerger(Sort groupSort) { |
| groupComp = new GroupComparator<>(groupSort); |
| queue = new TreeSet<>(groupComp); |
| groupsSeen = new HashMap<>(); |
| } |
| |
| @SuppressWarnings({"unchecked","rawtypes"}) |
| private void updateNextGroup(int topN, ShardIter<T> shard) { |
| while(shard.iter.hasNext()) { |
| final SearchGroup<T> group = shard.next(); |
| MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue); |
| final boolean isNew = mergedGroup == null; |
| //System.out.println(" next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues)); |
| |
| if (isNew) { |
| // Start a new group: |
| //System.out.println(" new"); |
| mergedGroup = new MergedGroup<>(group.groupValue); |
| mergedGroup.minShardIndex = shard.shardIndex; |
| assert group.sortValues != null; |
| mergedGroup.topValues = group.sortValues; |
| groupsSeen.put(group.groupValue, mergedGroup); |
| mergedGroup.inQueue = true; |
| queue.add(mergedGroup); |
| } else if (mergedGroup.processed) { |
| // This shard produced a group that we already |
| // processed; move on to next group... |
| continue; |
| } else { |
| //System.out.println(" old"); |
| boolean competes = false; |
| for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) { |
| final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX], |
| mergedGroup.topValues[compIDX]); |
| if (cmp < 0) { |
| // Definitely competes |
| competes = true; |
| break; |
| } else if (cmp > 0) { |
| // Definitely does not compete |
| break; |
| } else if (compIDX == groupComp.comparators.length-1) { |
| if (shard.shardIndex < mergedGroup.minShardIndex) { |
| competes = true; |
| } |
| } |
| } |
| |
| //System.out.println(" competes=" + competes); |
| |
| if (competes) { |
| // Group's sort changed -- remove & re-insert |
| if (mergedGroup.inQueue) { |
| queue.remove(mergedGroup); |
| } |
| mergedGroup.topValues = group.sortValues; |
| mergedGroup.minShardIndex = shard.shardIndex; |
| queue.add(mergedGroup); |
| mergedGroup.inQueue = true; |
| } |
| } |
| |
| mergedGroup.shards.add(shard); |
| break; |
| } |
| |
| // Prune un-competitive groups: |
| while(queue.size() > topN) { |
| final MergedGroup<T> group = queue.pollLast(); |
| //System.out.println("PRUNE: " + group); |
| group.inQueue = false; |
| } |
| } |
| |
| public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) { |
| |
| final int maxQueueSize = offset + topN; |
| |
| //System.out.println("merge"); |
| // Init queue: |
| for(int shardIDX=0;shardIDX<shards.size();shardIDX++) { |
| final Collection<SearchGroup<T>> shard = shards.get(shardIDX); |
| if (!shard.isEmpty()) { |
| //System.out.println(" insert shard=" + shardIDX); |
| updateNextGroup(maxQueueSize, new ShardIter<>(shard, shardIDX)); |
| } |
| } |
| |
| // Pull merged topN groups: |
| final List<SearchGroup<T>> newTopGroups = new ArrayList<>(topN); |
| |
| int count = 0; |
| |
| while(!queue.isEmpty()) { |
| final MergedGroup<T> group = queue.pollFirst(); |
| group.processed = true; |
| //System.out.println(" pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues)); |
| if (count++ >= offset) { |
| final SearchGroup<T> newGroup = new SearchGroup<>(); |
| newGroup.groupValue = group.groupValue; |
| newGroup.sortValues = group.topValues; |
| newTopGroups.add(newGroup); |
| if (newTopGroups.size() == topN) { |
| break; |
| } |
| //} else { |
| // System.out.println(" skip < offset"); |
| } |
| |
| // Advance all iters in this group: |
| for(ShardIter<T> shardIter : group.shards) { |
| updateNextGroup(maxQueueSize, shardIter); |
| } |
| } |
| |
| if (newTopGroups.isEmpty()) { |
| return null; |
| } else { |
| return newTopGroups; |
| } |
| } |
| } |
| |
| /** Merges multiple collections of top groups, for example |
| * obtained from separate index shards. The provided |
| * groupSort must match how the groups were sorted, and |
| * the provided SearchGroups must have been computed |
| * with fillFields=true passed to {@link |
| * FirstPassGroupingCollector#getTopGroups}. |
| * |
| * <p>NOTE: this returns null if the topGroups is empty. |
| */ |
| public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort) { |
| if (topGroups.isEmpty()) { |
| return null; |
| } else { |
| return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN); |
| } |
| } |
| } |