blob: 58e1f74350a5da3523c79829aa019d99a126f7e0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.grouping;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.TreeSet;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
/**
* Represents a group that is found during the first pass search.
*
* @lucene.experimental
*/
public class SearchGroup<T> {
/** The value that defines this group */
public T groupValue;
/** The sort values used during sorting. These are the
* groupSort field values of the highest rank document
* (by the groupSort) within the group. Can be
* <code>null</code> if <code>fillFields=false</code> had
* been passed to {@link FirstPassGroupingCollector#getTopGroups} */
public Object[] sortValues;
@Override
public String toString() {
return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")");
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SearchGroup<?> that = (SearchGroup<?>) o;
if (groupValue == null) {
if (that.groupValue != null) {
return false;
}
} else if (!groupValue.equals(that.groupValue)) {
return false;
}
return true;
}
@Override
public int hashCode() {
return groupValue != null ? groupValue.hashCode() : 0;
}
private static class ShardIter<T> {
public final Iterator<SearchGroup<T>> iter;
public final int shardIndex;
public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) {
this.shardIndex = shardIndex;
iter = shard.iterator();
assert iter.hasNext();
}
public SearchGroup<T> next() {
assert iter.hasNext();
final SearchGroup<T> group = iter.next();
if (group.sortValues == null) {
throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector");
}
return group;
}
@Override
public String toString() {
return "ShardIter(shard=" + shardIndex + ")";
}
}
// Holds all shards currently on the same group
private static class MergedGroup<T> {
// groupValue may be null!
public final T groupValue;
public Object[] topValues;
public final List<ShardIter<T>> shards = new ArrayList<>();
public int minShardIndex;
public boolean processed;
public boolean inQueue;
public MergedGroup(T groupValue) {
this.groupValue = groupValue;
}
// Only for assert
private boolean neverEquals(Object _other) {
if (_other instanceof MergedGroup) {
MergedGroup<?> other = (MergedGroup<?>) _other;
if (groupValue == null) {
assert other.groupValue != null;
} else {
assert !groupValue.equals(other.groupValue);
}
}
return true;
}
@Override
public boolean equals(Object _other) {
// We never have another MergedGroup instance with
// same groupValue
assert neverEquals(_other);
if (_other instanceof MergedGroup) {
MergedGroup<?> other = (MergedGroup<?>) _other;
if (groupValue == null) {
return other == null;
} else {
return groupValue.equals(other);
}
} else {
return false;
}
}
@Override
public int hashCode() {
if (groupValue == null) {
return 0;
} else {
return groupValue.hashCode();
}
}
}
private static class GroupComparator<T> implements Comparator<MergedGroup<T>> {
@SuppressWarnings("rawtypes")
public final FieldComparator[] comparators;
public final int[] reversed;
@SuppressWarnings({"unchecked", "rawtypes"})
public GroupComparator(Sort groupSort) {
final SortField[] sortFields = groupSort.getSort();
comparators = new FieldComparator[sortFields.length];
reversed = new int[sortFields.length];
for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
final SortField sortField = sortFields[compIDX];
comparators[compIDX] = sortField.getComparator(1, compIDX);
reversed[compIDX] = sortField.getReverse() ? -1 : 1;
}
}
@Override
@SuppressWarnings({"unchecked","rawtypes"})
public int compare(MergedGroup<T> group, MergedGroup<T> other) {
if (group == other) {
return 0;
}
//System.out.println("compare group=" + group + " other=" + other);
final Object[] groupValues = group.topValues;
final Object[] otherValues = other.topValues;
//System.out.println(" groupValues=" + groupValues + " otherValues=" + otherValues);
for (int compIDX = 0;compIDX < comparators.length; compIDX++) {
final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX],
otherValues[compIDX]);
if (c != 0) {
return c;
}
}
// Tie break by min shard index:
assert group.minShardIndex != other.minShardIndex;
return group.minShardIndex - other.minShardIndex;
}
}
private static class GroupMerger<T> {
private final GroupComparator<T> groupComp;
private final NavigableSet<MergedGroup<T>> queue;
private final Map<T,MergedGroup<T>> groupsSeen;
public GroupMerger(Sort groupSort) {
groupComp = new GroupComparator<>(groupSort);
queue = new TreeSet<>(groupComp);
groupsSeen = new HashMap<>();
}
@SuppressWarnings({"unchecked","rawtypes"})
private void updateNextGroup(int topN, ShardIter<T> shard) {
while(shard.iter.hasNext()) {
final SearchGroup<T> group = shard.next();
MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue);
final boolean isNew = mergedGroup == null;
//System.out.println(" next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));
if (isNew) {
// Start a new group:
//System.out.println(" new");
mergedGroup = new MergedGroup<>(group.groupValue);
mergedGroup.minShardIndex = shard.shardIndex;
assert group.sortValues != null;
mergedGroup.topValues = group.sortValues;
groupsSeen.put(group.groupValue, mergedGroup);
mergedGroup.inQueue = true;
queue.add(mergedGroup);
} else if (mergedGroup.processed) {
// This shard produced a group that we already
// processed; move on to next group...
continue;
} else {
//System.out.println(" old");
boolean competes = false;
for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) {
final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX],
mergedGroup.topValues[compIDX]);
if (cmp < 0) {
// Definitely competes
competes = true;
break;
} else if (cmp > 0) {
// Definitely does not compete
break;
} else if (compIDX == groupComp.comparators.length-1) {
if (shard.shardIndex < mergedGroup.minShardIndex) {
competes = true;
}
}
}
//System.out.println(" competes=" + competes);
if (competes) {
// Group's sort changed -- remove & re-insert
if (mergedGroup.inQueue) {
queue.remove(mergedGroup);
}
mergedGroup.topValues = group.sortValues;
mergedGroup.minShardIndex = shard.shardIndex;
queue.add(mergedGroup);
mergedGroup.inQueue = true;
}
}
mergedGroup.shards.add(shard);
break;
}
// Prune un-competitive groups:
while(queue.size() > topN) {
final MergedGroup<T> group = queue.pollLast();
//System.out.println("PRUNE: " + group);
group.inQueue = false;
}
}
public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) {
final int maxQueueSize = offset + topN;
//System.out.println("merge");
// Init queue:
for(int shardIDX=0;shardIDX<shards.size();shardIDX++) {
final Collection<SearchGroup<T>> shard = shards.get(shardIDX);
if (!shard.isEmpty()) {
//System.out.println(" insert shard=" + shardIDX);
updateNextGroup(maxQueueSize, new ShardIter<>(shard, shardIDX));
}
}
// Pull merged topN groups:
final List<SearchGroup<T>> newTopGroups = new ArrayList<>(topN);
int count = 0;
while(!queue.isEmpty()) {
final MergedGroup<T> group = queue.pollFirst();
group.processed = true;
//System.out.println(" pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues));
if (count++ >= offset) {
final SearchGroup<T> newGroup = new SearchGroup<>();
newGroup.groupValue = group.groupValue;
newGroup.sortValues = group.topValues;
newTopGroups.add(newGroup);
if (newTopGroups.size() == topN) {
break;
}
//} else {
// System.out.println(" skip < offset");
}
// Advance all iters in this group:
for(ShardIter<T> shardIter : group.shards) {
updateNextGroup(maxQueueSize, shardIter);
}
}
if (newTopGroups.isEmpty()) {
return null;
} else {
return newTopGroups;
}
}
}
/** Merges multiple collections of top groups, for example
* obtained from separate index shards. The provided
* groupSort must match how the groups were sorted, and
* the provided SearchGroups must have been computed
* with fillFields=true passed to {@link
* FirstPassGroupingCollector#getTopGroups}.
*
* <p>NOTE: this returns null if the topGroups is empty.
*/
public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort) {
if (topGroups.isEmpty()) {
return null;
} else {
return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN);
}
}
}