blob: 9fb778dc832966ab283e5d1ed23a0f2675b9f641 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.CollectionUtil;
/**
 * A conjunction of DocIdSetIterators. All of its sub-iterators are required to be positioned on
 * the same document at all times. This iterates over the doc ids that are present in every given
 * DocIdSetIterator. <br>
 * Public only for use in {@link org.apache.lucene.search.spans}.
 *
 * @lucene.internal
 */
public final class ConjunctionDISI extends DocIdSetIterator {

  /**
   * Create a conjunction over the provided {@link Scorer}s. Note that the returned {@link
   * DocIdSetIterator} might leverage two-phase iteration in which case it is possible to retrieve
   * the {@link TwoPhaseIterator} using {@link TwoPhaseIterator#unwrap}.
   *
   * @param scorers the scorers to intersect; at least two are required
   * @return an iterator over the documents matched by every scorer
   * @throws IllegalArgumentException if fewer than two scorers are given, or if their iterators
   *     are not all positioned on the same document
   */
  public static DocIdSetIterator intersectScorers(Collection<Scorer> scorers) {
    if (scorers.size() < 2) {
      throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators");
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (Scorer scorer : scorers) {
      addScorer(scorer, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Create a conjunction over the provided DocIdSetIterators. Note that the returned {@link
   * DocIdSetIterator} might leverage two-phase iteration in which case it is possible to retrieve
   * the {@link TwoPhaseIterator} using {@link TwoPhaseIterator#unwrap}.
   *
   * @param iterators the iterators to intersect; at least two are required
   * @return an iterator over the documents present in every given iterator
   * @throws IllegalArgumentException if fewer than two iterators are given, or if they are not all
   *     positioned on the same document
   */
  public static DocIdSetIterator intersectIterators(List<DocIdSetIterator> iterators) {
    if (iterators.size() < 2) {
      throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators");
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (DocIdSetIterator iterator : iterators) {
      addIterator(iterator, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Create a conjunction over the provided {@link Spans}. Note that the returned {@link
   * DocIdSetIterator} might leverage two-phase iteration in which case it is possible to retrieve
   * the {@link TwoPhaseIterator} using {@link TwoPhaseIterator#unwrap}.
   *
   * @param spanList the spans to intersect; at least two are required
   * @return an iterator over the documents matched by every spans
   * @throws IllegalArgumentException if fewer than two spans are given, or if they are not all
   *     positioned on the same document
   */
  public static DocIdSetIterator intersectSpans(List<Spans> spanList) {
    if (spanList.size() < 2) {
      throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators");
    }
    final List<DocIdSetIterator> allIterators = new ArrayList<>();
    final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>();
    for (Spans spans : spanList) {
      addSpans(spans, allIterators, twoPhaseIterators);
    }
    return createConjunction(allIterators, twoPhaseIterators);
  }

  /**
   * Adds the scorer, possibly splitting it up into two phases or collapsing it if it is another
   * conjunction.
   */
  private static void addScorer(
      Scorer scorer,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhaseIter = scorer.twoPhaseIterator();
    if (twoPhaseIter != null) {
      addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators);
    } else { // no approximation support, use the iterator as-is
      addIterator(scorer.iterator(), allIterators, twoPhaseIterators);
    }
  }

  /** Adds the Spans, splitting them into approximation + match phase when supported. */
  private static void addSpans(
      Spans spans, List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhaseIter = spans.asTwoPhaseIterator();
    if (twoPhaseIter != null) {
      addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators);
    } else { // no approximation support, use the iterator as-is
      addIterator(spans, allIterators, twoPhaseIterators);
    }
  }

  /**
   * Adds a plain iterator. If it wraps a two-phase iterator, it is split. If it is itself a {@link
   * ConjunctionDISI} or {@link BitSetConjunctionDISI}, its sub-iterators are pulled up (flattened)
   * into this conjunction.
   */
  private static void addIterator(
      DocIdSetIterator disi,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    TwoPhaseIterator twoPhase = TwoPhaseIterator.unwrap(disi);
    if (twoPhase != null) {
      addTwoPhaseIterator(twoPhase, allIterators, twoPhaseIterators);
    } else if (disi.getClass() == ConjunctionDISI.class) { // exact class check for collapsing
      ConjunctionDISI conjunction = (ConjunctionDISI) disi;
      // Sub-conjunctions have already split themselves into two-phase iterators and others, so we
      // can take those iterators as they are and move them up to this conjunction.
      allIterators.add(conjunction.lead1);
      allIterators.add(conjunction.lead2);
      Collections.addAll(allIterators, conjunction.others);
    } else if (disi.getClass() == BitSetConjunctionDISI.class) {
      BitSetConjunctionDISI conjunction = (BitSetConjunctionDISI) disi;
      allIterators.add(conjunction.lead);
      Collections.addAll(allIterators, conjunction.bitSetIterators);
    } else {
      allIterators.add(disi);
    }
  }

  /**
   * Adds a two-phase iterator. Its approximation joins {@code allIterators} and its match phase
   * joins {@code twoPhaseIterators}. The match phases of a nested {@link
   * ConjunctionTwoPhaseIterator} are flattened into this conjunction.
   */
  private static void addTwoPhaseIterator(
      TwoPhaseIterator twoPhaseIter,
      List<DocIdSetIterator> allIterators,
      List<TwoPhaseIterator> twoPhaseIterators) {
    addIterator(twoPhaseIter.approximation(), allIterators, twoPhaseIterators);
    if (twoPhaseIter.getClass() == ConjunctionTwoPhaseIterator.class) { // exact class check
      Collections.addAll(
          twoPhaseIterators, ((ConjunctionTwoPhaseIterator) twoPhaseIter).twoPhaseIterators);
    } else {
      twoPhaseIterators.add(twoPhaseIter);
    }
  }

  /**
   * Builds the final iterator from the flattened sub-iterators.
   *
   * <p>Bit set iterators whose cost is above the minimum are checked by direct bit lookups in a
   * {@link BitSetConjunctionDISI} rather than being advanced. Any two-phase match phases are
   * applied on top via a {@link ConjunctionTwoPhaseIterator}.
   *
   * @throws IllegalArgumentException if the sub-iterators are not all on the same document
   */
  private static DocIdSetIterator createConjunction(
      List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) {
    // check that all sub-iterators are on the same doc ID
    int curDoc =
        allIterators.size() > 0
            ? allIterators.get(0).docID()
            : twoPhaseIterators.get(0).approximation().docID();
    boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc);
    iteratorsOnTheSameDoc =
        iteratorsOnTheSameDoc
            && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc);
    if (iteratorsOnTheSameDoc == false) {
      throw new IllegalArgumentException(
          "Sub-iterators of ConjunctionDISI are not on the same document!");
    }

    long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();
    List<BitSetIterator> bitSetIterators = new ArrayList<>();
    List<DocIdSetIterator> iterators = new ArrayList<>();
    for (DocIdSetIterator iterator : allIterators) {
      if (iterator.cost() > minCost && iterator instanceof BitSetIterator) {
        // We put all bit set iterators into bitSetIterators, except those with the minimum cost:
        // those may need to lead the iteration.
        bitSetIterators.add((BitSetIterator) iterator);
      } else {
        iterators.add(iterator);
      }
    }

    DocIdSetIterator disi;
    if (iterators.size() == 1) {
      disi = iterators.get(0);
    } else {
      disi = new ConjunctionDISI(iterators);
    }

    if (bitSetIterators.size() > 0) {
      disi = new BitSetConjunctionDISI(disi, bitSetIterators);
    }

    if (twoPhaseIterators.isEmpty() == false) {
      disi =
          TwoPhaseIterator.asDocIdSetIterator(
              new ConjunctionTwoPhaseIterator(disi, twoPhaseIterators));
    }

    return disi;
  }

  /** The two cheapest iterators (by {@link DocIdSetIterator#cost()}); {@code lead1} drives. */
  final DocIdSetIterator lead1, lead2;

  /** The remaining iterators, in ascending cost order. */
  final DocIdSetIterator[] others;

  /**
   * Builds a conjunction over at least two iterators. Note: {@code iterators} is sorted in place
   * by ascending cost.
   */
  private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) {
    assert iterators.size() >= 2;
    // Sort the iterators once so that the least frequent one leads the matching.
    CollectionUtil.timSort(iterators, Comparator.comparingLong(DocIdSetIterator::cost));
    lead1 = iterators.get(0);
    lead2 = iterators.get(1);
    others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]);
  }

  /**
   * Starting from {@code doc}, which {@code lead1} is positioned on, advances all sub-iterators
   * until they agree on a document.
   *
   * @return the first doc that all sub-iterators share (possibly {@link #NO_MORE_DOCS})
   */
  private int doNext(int doc) throws IOException {
    advanceHead:
    for (; ; ) {
      assert doc == lead1.docID();

      // Find agreement between the two iterators with the lower costs. They are special-cased
      // because they do not need the 'other.docID() < doc' check that the 'others' iterators need.
      final int next2 = lead2.advance(doc);
      if (next2 != doc) {
        doc = lead1.advance(next2);
        if (next2 != doc) {
          continue;
        }
      }

      // then find agreement with the other iterators
      for (DocIdSetIterator other : others) {
        // other.docID() may already be equal to doc if we "continued advanceHead" on the previous
        // iteration and the advance on the lead scorer exactly matched.
        if (other.docID() < doc) {
          final int next = other.advance(doc);

          if (next > doc) {
            // iterator beyond the current doc - advance lead and continue to the new highest doc.
            doc = lead1.advance(next);
            continue advanceHead;
          }
        }
      }

      // success - all iterators are on the same doc
      return doc;
    }
  }

  @Override
  public int advance(int target) throws IOException {
    assert assertItersOnSameDoc()
        : "Sub-iterators of ConjunctionDISI are not on the same document!";
    return doNext(lead1.advance(target));
  }

  @Override
  public int docID() {
    return lead1.docID();
  }

  @Override
  public int nextDoc() throws IOException {
    assert assertItersOnSameDoc()
        : "Sub-iterators of ConjunctionDISI are not on the same document!";
    return doNext(lead1.nextDoc());
  }

  @Override
  public long cost() {
    return lead1.cost(); // overestimate
  }

  /** Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise. */
  private boolean assertItersOnSameDoc() {
    int curDoc = lead1.docID();
    boolean iteratorsOnTheSameDoc = (lead2.docID() == curDoc);
    for (int i = 0; (i < others.length && iteratorsOnTheSameDoc); i++) {
      iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (others[i].docID() == curDoc);
    }
    return iteratorsOnTheSameDoc;
  }

  /** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */
  private static class BitSetConjunctionDISI extends DocIdSetIterator {

    /** Drives the iteration; the bit sets are only probed at the lead's documents. */
    private final DocIdSetIterator lead;

    /** Bit set iterators in ascending cost order; kept positioned on the lead's doc. */
    private final BitSetIterator[] bitSetIterators;

    /** The bit sets backing {@link #bitSetIterators}, in the same order. */
    private final BitSet[] bitSets;

    /** Smallest {@link BitSet#length()}; no doc at or beyond it can be in every bit set. */
    private final int minLength;

    BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) {
      this.lead = lead;
      assert bitSetIterators.size() > 0;
      this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]);
      // Put the least costly iterators first so that we exit as soon as possible
      ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost()));
      this.bitSets = new BitSet[this.bitSetIterators.length];
      int minLen = Integer.MAX_VALUE;
      for (int i = 0; i < this.bitSetIterators.length; ++i) {
        BitSet bitSet = this.bitSetIterators[i].getBitSet();
        this.bitSets[i] = bitSet;
        minLen = Math.min(minLen, bitSet.length());
      }
      this.minLength = minLen;
    }

    @Override
    public int docID() {
      return lead.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      assert assertItersOnSameDoc()
          : "Sub-iterators of ConjunctionDISI are not on the same document!";
      return doNext(lead.nextDoc());
    }

    @Override
    public int advance(int target) throws IOException {
      assert assertItersOnSameDoc()
          : "Sub-iterators of ConjunctionDISI are not on the same document!";
      return doNext(lead.advance(target));
    }

    /**
     * Starting from the lead's current {@code doc}, moves the lead forward until a doc is set in
     * every bit set, then positions the bit set iterators on it.
     *
     * @return the matching doc, or {@link #NO_MORE_DOCS} when exhausted
     */
    private int doNext(int doc) throws IOException {
      advanceLead:
      for (; ; doc = lead.nextDoc()) {
        if (doc >= minLength) {
          // No bit set can contain this doc or any later one, so the conjunction is exhausted.
          // Exhaust the lead too: docID() delegates to lead.docID() and must report NO_MORE_DOCS
          // once NO_MORE_DOCS has been returned.
          if (doc != NO_MORE_DOCS) {
            lead.advance(NO_MORE_DOCS);
          }
          return NO_MORE_DOCS;
        }
        for (BitSet bitSet : bitSets) {
          if (bitSet.get(doc) == false) {
            continue advanceLead;
          }
        }
        for (BitSetIterator iterator : bitSetIterators) {
          iterator.setDocId(doc);
        }
        return doc;
      }
    }

    @Override
    public long cost() {
      return lead.cost();
    }

    /** Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise. */
    private boolean assertItersOnSameDoc() {
      int curDoc = lead.docID();
      boolean iteratorsOnTheSameDoc = true;
      for (int i = 0; (i < bitSetIterators.length && iteratorsOnTheSameDoc); i++) {
        iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (bitSetIterators[i].docID() == curDoc);
      }
      return iteratorsOnTheSameDoc;
    }
  }

  /** {@link TwoPhaseIterator} implementing a conjunction. */
  private static final class ConjunctionTwoPhaseIterator extends TwoPhaseIterator {

    /** Match phases, sorted by ascending {@link TwoPhaseIterator#matchCost()}. */
    private final TwoPhaseIterator[] twoPhaseIterators;

    /** Sum of the sub-iterators' match costs. */
    private final float matchCost;

    /**
     * Note: {@code twoPhaseIterators} is sorted in place by ascending match cost.
     *
     * @param approximation the conjunction of all approximations
     * @param twoPhaseIterators the match phases to check; must not be empty
     */
    private ConjunctionTwoPhaseIterator(
        DocIdSetIterator approximation, List<? extends TwoPhaseIterator> twoPhaseIterators) {
      super(approximation);
      assert twoPhaseIterators.size() > 0;

      CollectionUtil.timSort(
          twoPhaseIterators, (o1, o2) -> Float.compare(o1.matchCost(), o2.matchCost()));

      this.twoPhaseIterators = twoPhaseIterators.toArray(new TwoPhaseIterator[0]);

      // Compute the matchCost as the total matchCost of the sub iterators.
      // TODO: This could be too high because the matching is done cheapest first: give the lower
      // matchCosts a higher weight.
      float totalMatchCost = 0;
      for (TwoPhaseIterator tpi : this.twoPhaseIterators) {
        totalMatchCost += tpi.matchCost();
      }
      matchCost = totalMatchCost;
    }

    @Override
    public boolean matches() throws IOException {
      for (TwoPhaseIterator twoPhaseIterator : twoPhaseIterators) { // match cheapest first
        if (twoPhaseIterator.matches() == false) {
          return false;
        }
      }
      return true;
    }

    @Override
    public float matchCost() {
      return matchCost;
    }
  }
}