blob: 7183aacfae4c43239b4585e7289f9d6cc8814cf3 [file] [log] [blame]
package org.apache.lucene.search.grouping.term;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.grouping.AbstractAllGroupHeadsCollector;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SentinelIntSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A base implementation of {@link org.apache.lucene.search.grouping.AbstractAllGroupHeadsCollector} for retrieving the most relevant groups when grouping
* on a string based group field. More specifically this all concrete implementations of this base implementation
* use {@link org.apache.lucene.index.SortedDocValues}.
*
* @lucene.experimental
*/
public abstract class TermAllGroupHeadsCollector<GH extends AbstractAllGroupHeadsCollector.GroupHead<?>> extends AbstractAllGroupHeadsCollector<GH> {
private static final int DEFAULT_INITIAL_SIZE = 128;
final String groupField;
final BytesRef scratchBytesRef = new BytesRef();
SortedDocValues groupIndex;
AtomicReaderContext readerContext;
protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) {
super(numberOfSorts);
this.groupField = groupField;
}
/**
* Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
* This factory method decides with implementation is best suited.
*
* Delegates to {@link #create(String, org.apache.lucene.search.Sort, int)} with an initialSize of 128.
*
* @param groupField The field to group by
* @param sortWithinGroup The sort within each group
* @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
*/
public static AbstractAllGroupHeadsCollector<?> create(String groupField, Sort sortWithinGroup) {
return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE);
}
/**
* Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
* This factory method decides with implementation is best suited.
*
* @param groupField The field to group by
* @param sortWithinGroup The sort within each group
* @param initialSize The initial allocation size of the internal int set and group list which should roughly match
* the total number of expected unique groups. Be aware that the heap usage is
* 4 bytes * initialSize.
* @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
*/
public static AbstractAllGroupHeadsCollector<?> create(String groupField, Sort sortWithinGroup, int initialSize) {
boolean sortAllScore = true;
boolean sortAllFieldValue = true;
for (SortField sortField : sortWithinGroup.getSort()) {
if (sortField.getType() == SortField.Type.SCORE) {
sortAllFieldValue = false;
} else if (needGeneralImpl(sortField)) {
return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup);
} else {
sortAllScore = false;
}
}
if (sortAllScore) {
return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
} else if (sortAllFieldValue) {
return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
} else {
return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
}
}
// Returns when a sort field needs the general impl.
private static boolean needGeneralImpl(SortField sortField) {
SortField.Type sortType = sortField.getType();
// Note (MvG): We can also make an optimized impl when sorting is SortField.DOC
return sortType != SortField.Type.STRING_VAL && sortType != SortField.Type.STRING && sortType != SortField.Type.SCORE;
}
// A general impl that works for any group sort.
static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector<GeneralAllGroupHeadsCollector.GroupHead> {
private final Sort sortWithinGroup;
private final Map<BytesRef, GroupHead> groups;
private Scorer scorer;
GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) {
super(groupField, sortWithinGroup.getSort().length);
this.sortWithinGroup = sortWithinGroup;
groups = new HashMap<>();
final SortField[] sortFields = sortWithinGroup.getSort();
for (int i = 0; i < sortFields.length; i++) {
reversed[i] = sortFields[i].getReverse() ? -1 : 1;
}
}
@Override
protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
final int ord = groupIndex.getOrd(doc);
final BytesRef groupValue;
if (ord == -1) {
groupValue = null;
} else {
groupIndex.lookupOrd(ord, scratchBytesRef);
groupValue = scratchBytesRef;
}
GroupHead groupHead = groups.get(groupValue);
if (groupHead == null) {
groupHead = new GroupHead(groupValue, sortWithinGroup, doc);
groups.put(groupValue == null ? null : BytesRef.deepCopyOf(groupValue), groupHead);
temporalResult.stop = true;
} else {
temporalResult.stop = false;
}
temporalResult.groupHead = groupHead;
}
@Override
protected Collection<GroupHead> getCollectedGroupHeads() {
return groups.values();
}
@Override
protected void doSetNextReader(AtomicReaderContext context) throws IOException {
this.readerContext = context;
groupIndex = DocValues.getSorted(context.reader(), groupField);
for (GroupHead groupHead : groups.values()) {
for (int i = 0; i < groupHead.comparators.length; i++) {
groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context);
}
}
}
@Override
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
for (GroupHead groupHead : groups.values()) {
for (FieldComparator<?> comparator : groupHead.comparators) {
comparator.setScorer(scorer);
}
}
}
class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {
final FieldComparator<?>[] comparators;
@SuppressWarnings({"unchecked","rawtypes"})
private GroupHead(BytesRef groupValue, Sort sort, int doc) throws IOException {
super(groupValue, doc + readerContext.docBase);
final SortField[] sortFields = sort.getSort();
comparators = new FieldComparator[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext);
comparators[i].setScorer(scorer);
comparators[i].copy(0, doc);
comparators[i].setBottom(0);
}
}
@Override
public int compare(int compIDX, int doc) throws IOException {
return comparators[compIDX].compareBottom(doc);
}
@Override
public void updateDocHead(int doc) throws IOException {
for (FieldComparator<?> comparator : comparators) {
comparator.copy(0, doc);
comparator.setBottom(0);
}
this.doc = doc + readerContext.docBase;
}
}
}
// AbstractAllGroupHeadsCollector optimized for ord fields and scores.
static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdScoreAllGroupHeadsCollector.GroupHead> {
private final SentinelIntSet ordSet;
private final List<GroupHead> collectedGroups;
private final SortField[] fields;
private SortedDocValues[] sortsIndex;
private Scorer scorer;
private GroupHead[] segmentGroupHeads;
OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
super(groupField, sortWithinGroup.getSort().length);
ordSet = new SentinelIntSet(initialSize, -2);
collectedGroups = new ArrayList<>(initialSize);
final SortField[] sortFields = sortWithinGroup.getSort();
fields = new SortField[sortFields.length];
sortsIndex = new SortedDocValues[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
reversed[i] = sortFields[i].getReverse() ? -1 : 1;
fields[i] = sortFields[i];
}
}
@Override
protected Collection<GroupHead> getCollectedGroupHeads() {
return collectedGroups;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
}
@Override
protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
int key = groupIndex.getOrd(doc);
GroupHead groupHead;
if (!ordSet.exists(key)) {
ordSet.put(key);
BytesRef term;
if (key == -1) {
term = null;
} else {
term = new BytesRef();
groupIndex.lookupOrd(key, term);
}
groupHead = new GroupHead(doc, term);
collectedGroups.add(groupHead);
segmentGroupHeads[key+1] = groupHead;
temporalResult.stop = true;
} else {
temporalResult.stop = false;
groupHead = segmentGroupHeads[key+1];
}
temporalResult.groupHead = groupHead;
}
@Override
protected void doSetNextReader(AtomicReaderContext context) throws IOException {
this.readerContext = context;
groupIndex = DocValues.getSorted(context.reader(), groupField);
for (int i = 0; i < fields.length; i++) {
if (fields[i].getType() == SortField.Type.SCORE) {
continue;
}
sortsIndex[i] = DocValues.getSorted(context.reader(), fields[i].getField());
}
// Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
ordSet.clear();
segmentGroupHeads = new GroupHead[groupIndex.getValueCount()+1];
for (GroupHead collectedGroup : collectedGroups) {
int ord;
if (collectedGroup.groupValue == null) {
ord = -1;
} else {
ord = groupIndex.lookupTerm(collectedGroup.groupValue);
}
if (collectedGroup.groupValue == null || ord >= 0) {
ordSet.put(ord);
segmentGroupHeads[ord+1] = collectedGroup;
for (int i = 0; i < sortsIndex.length; i++) {
if (fields[i].getType() == SortField.Type.SCORE) {
continue;
}
int sortOrd;
if (collectedGroup.sortValues[i] == null) {
sortOrd = -1;
} else {
sortOrd = sortsIndex[i].lookupTerm(collectedGroup.sortValues[i]);
}
collectedGroup.sortOrds[i] = sortOrd;
}
}
}
}
class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {
BytesRef[] sortValues;
int[] sortOrds;
float[] scores;
private GroupHead(int doc, BytesRef groupValue) throws IOException {
super(groupValue, doc + readerContext.docBase);
sortValues = new BytesRef[sortsIndex.length];
sortOrds = new int[sortsIndex.length];
scores = new float[sortsIndex.length];
for (int i = 0; i < sortsIndex.length; i++) {
if (fields[i].getType() == SortField.Type.SCORE) {
scores[i] = scorer.score();
} else {
sortOrds[i] = sortsIndex[i].getOrd(doc);
sortValues[i] = new BytesRef();
if (sortOrds[i] != -1) {
sortsIndex[i].get(doc, sortValues[i]);
}
}
}
}
@Override
public int compare(int compIDX, int doc) throws IOException {
if (fields[compIDX].getType() == SortField.Type.SCORE) {
float score = scorer.score();
if (scores[compIDX] < score) {
return 1;
} else if (scores[compIDX] > score) {
return -1;
}
return 0;
} else {
if (sortOrds[compIDX] < 0) {
// The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
if (sortsIndex[compIDX].getOrd(doc) == -1) {
scratchBytesRef.length = 0;
} else {
sortsIndex[compIDX].get(doc, scratchBytesRef);
}
return sortValues[compIDX].compareTo(scratchBytesRef);
} else {
return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
}
}
}
@Override
public void updateDocHead(int doc) throws IOException {
for (int i = 0; i < sortsIndex.length; i++) {
if (fields[i].getType() == SortField.Type.SCORE) {
scores[i] = scorer.score();
} else {
sortOrds[i] = sortsIndex[i].getOrd(doc);
if (sortOrds[i] == -1) {
sortValues[i].length = 0;
} else {
sortsIndex[i].get(doc, sortValues[i]);
}
}
}
this.doc = doc + readerContext.docBase;
}
}
}
// AbstractAllGroupHeadsCollector optimized for ord fields.
static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdAllGroupHeadsCollector.GroupHead> {
private final SentinelIntSet ordSet;
private final List<GroupHead> collectedGroups;
private final SortField[] fields;
private SortedDocValues[] sortsIndex;
private GroupHead[] segmentGroupHeads;
OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
super(groupField, sortWithinGroup.getSort().length);
ordSet = new SentinelIntSet(initialSize, -2);
collectedGroups = new ArrayList<>(initialSize);
final SortField[] sortFields = sortWithinGroup.getSort();
fields = new SortField[sortFields.length];
sortsIndex = new SortedDocValues[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
reversed[i] = sortFields[i].getReverse() ? -1 : 1;
fields[i] = sortFields[i];
}
}
@Override
protected Collection<GroupHead> getCollectedGroupHeads() {
return collectedGroups;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
}
@Override
protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
int key = groupIndex.getOrd(doc);
GroupHead groupHead;
if (!ordSet.exists(key)) {
ordSet.put(key);
BytesRef term;
if (key == -1) {
term = null;
} else {
term = new BytesRef();
groupIndex.lookupOrd(key, term);
}
groupHead = new GroupHead(doc, term);
collectedGroups.add(groupHead);
segmentGroupHeads[key+1] = groupHead;
temporalResult.stop = true;
} else {
temporalResult.stop = false;
groupHead = segmentGroupHeads[key+1];
}
temporalResult.groupHead = groupHead;
}
@Override
protected void doSetNextReader(AtomicReaderContext context) throws IOException {
this.readerContext = context;
groupIndex = DocValues.getSorted(context.reader(), groupField);
for (int i = 0; i < fields.length; i++) {
sortsIndex[i] = DocValues.getSorted(context.reader(), fields[i].getField());
}
// Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
ordSet.clear();
segmentGroupHeads = new GroupHead[groupIndex.getValueCount()+1];
for (GroupHead collectedGroup : collectedGroups) {
int groupOrd;
if (collectedGroup.groupValue == null) {
groupOrd = -1;
} else {
groupOrd = groupIndex.lookupTerm(collectedGroup.groupValue);
}
if (collectedGroup.groupValue == null || groupOrd >= 0) {
ordSet.put(groupOrd);
segmentGroupHeads[groupOrd+1] = collectedGroup;
for (int i = 0; i < sortsIndex.length; i++) {
int sortOrd;
if (collectedGroup.sortOrds[i] == -1) {
sortOrd = -1;
} else {
sortOrd = sortsIndex[i].lookupTerm(collectedGroup.sortValues[i]);
}
collectedGroup.sortOrds[i] = sortOrd;
}
}
}
}
class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {
BytesRef[] sortValues;
int[] sortOrds;
private GroupHead(int doc, BytesRef groupValue) {
super(groupValue, doc + readerContext.docBase);
sortValues = new BytesRef[sortsIndex.length];
sortOrds = new int[sortsIndex.length];
for (int i = 0; i < sortsIndex.length; i++) {
sortOrds[i] = sortsIndex[i].getOrd(doc);
sortValues[i] = new BytesRef();
if (sortOrds[i] != -1) {
sortsIndex[i].get(doc, sortValues[i]);
}
}
}
@Override
public int compare(int compIDX, int doc) throws IOException {
if (sortOrds[compIDX] < 0) {
// The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
if (sortsIndex[compIDX].getOrd(doc) == -1) {
scratchBytesRef.length = 0;
} else {
sortsIndex[compIDX].get(doc, scratchBytesRef);
}
return sortValues[compIDX].compareTo(scratchBytesRef);
} else {
return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
}
}
@Override
public void updateDocHead(int doc) throws IOException {
for (int i = 0; i < sortsIndex.length; i++) {
sortOrds[i] = sortsIndex[i].getOrd(doc);
if (sortOrds[i] == -1) {
sortValues[i].length = 0;
} else {
sortsIndex[i].lookupOrd(sortOrds[i], sortValues[i]);
}
}
this.doc = doc + readerContext.docBase;
}
}
}
// AbstractAllGroupHeadsCollector optimized for scores.
static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<ScoreAllGroupHeadsCollector.GroupHead> {
private final SentinelIntSet ordSet;
private final List<GroupHead> collectedGroups;
private final SortField[] fields;
private Scorer scorer;
private GroupHead[] segmentGroupHeads;
ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
super(groupField, sortWithinGroup.getSort().length);
ordSet = new SentinelIntSet(initialSize, -2);
collectedGroups = new ArrayList<>(initialSize);
final SortField[] sortFields = sortWithinGroup.getSort();
fields = new SortField[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
reversed[i] = sortFields[i].getReverse() ? -1 : 1;
fields[i] = sortFields[i];
}
}
@Override
protected Collection<GroupHead> getCollectedGroupHeads() {
return collectedGroups;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
}
@Override
protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
int key = groupIndex.getOrd(doc);
GroupHead groupHead;
if (!ordSet.exists(key)) {
ordSet.put(key);
BytesRef term;
if (key == -1) {
term = null;
} else {
term = new BytesRef();
groupIndex.lookupOrd(key, term);
}
groupHead = new GroupHead(doc, term);
collectedGroups.add(groupHead);
segmentGroupHeads[key+1] = groupHead;
temporalResult.stop = true;
} else {
temporalResult.stop = false;
groupHead = segmentGroupHeads[key+1];
}
temporalResult.groupHead = groupHead;
}
@Override
protected void doSetNextReader(AtomicReaderContext context) throws IOException {
this.readerContext = context;
groupIndex = DocValues.getSorted(context.reader(), groupField);
// Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
ordSet.clear();
segmentGroupHeads = new GroupHead[groupIndex.getValueCount()+1];
for (GroupHead collectedGroup : collectedGroups) {
int ord;
if (collectedGroup.groupValue == null) {
ord = -1;
} else {
ord = groupIndex.lookupTerm(collectedGroup.groupValue);
}
if (collectedGroup.groupValue == null || ord >= 0) {
ordSet.put(ord);
segmentGroupHeads[ord+1] = collectedGroup;
}
}
}
class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {
float[] scores;
private GroupHead(int doc, BytesRef groupValue) throws IOException {
super(groupValue, doc + readerContext.docBase);
scores = new float[fields.length];
float score = scorer.score();
for (int i = 0; i < scores.length; i++) {
scores[i] = score;
}
}
@Override
public int compare(int compIDX, int doc) throws IOException {
float score = scorer.score();
if (scores[compIDX] < score) {
return 1;
} else if (scores[compIDX] > score) {
return -1;
}
return 0;
}
@Override
public void updateDocHead(int doc) throws IOException {
float score = scorer.score();
for (int i = 0; i < scores.length; i++) {
scores[i] = score;
}
this.doc = doc + readerContext.docBase;
}
}
}
}