blob: 853411ed8966670d29827e2122a7f9380385cdf3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.util.BitDocIdSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
public class TestConjunctionDISI extends LuceneTestCase {
/**
 * Builds a {@link TwoPhaseIterator} over the given iterator whose confirmation phase
 * accepts the current document only if its bit is set in {@code confirmed}. The
 * iterator is randomly passed through {@link #anonymizeIterator(DocIdSetIterator)} first.
 *
 * @param iterator the iterator used as the approximation
 * @param confirmed bits of the documents that {@code matches()} confirms
 */
private static TwoPhaseIterator approximation(DocIdSetIterator iterator, final FixedBitSet confirmed) {
  // One coin flip decides whether the approximation is the raw iterator or its anonymous wrapper.
  final DocIdSetIterator candidates = random().nextBoolean() ? anonymizeIterator(iterator) : iterator;
  return new TwoPhaseIterator(candidates) {
    @Override
    public float matchCost() {
      return 5; // #operations in FixedBitSet#get()
    }

    @Override
    public boolean matches() {
      // Confirm the document the approximation is currently positioned on.
      final int current = candidates.docID();
      return confirmed.get(current);
    }
  };
}
/**
 * Returns an anonymous {@link DocIdSetIterator} that forwards every call to {@code it},
 * so that ConjunctionDISI cannot optimize it like it does e.g. for BitSetIterators.
 *
 * @param it the iterator to delegate to
 * @return a delegating iterator of an anonymous type
 */
private static DocIdSetIterator anonymizeIterator(DocIdSetIterator it) {
  return new DocIdSetIterator() {
    @Override
    public int docID() {
      return it.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      return it.nextDoc();
    }

    @Override
    public int advance(int target) throws IOException {
      return it.advance(target);
    }

    @Override
    public long cost() {
      // Forward the delegate's cost estimate, not its current doc ID.
      return it.cost();
    }
  };
}
/**
 * Creates a {@link Scorer} for the given two-phase iterator, using
 * {@link TwoPhaseIterator#asDocIdSetIterator(TwoPhaseIterator)} as the wrapped iterator and
 * the two-phase iterator itself as the two-phase view.
 */
private static Scorer scorer(TwoPhaseIterator twoPhaseIterator) {
  final DocIdSetIterator disiView = TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator);
  return scorer(disiView, twoPhaseIterator);
}
/**
 * Placeholder {@link Weight} built on a {@link MatchNoDocsQuery}. Its methods return
 * {@code null} / {@code false} or do nothing; it only exists so scorers can be constructed.
 */
private static class FakeWeight extends Weight {

  protected FakeWeight() {
    super(new MatchNoDocsQuery());
  }

  /** Always reports this weight as not cacheable. */
  @Override
  public boolean isCacheable(LeafReaderContext ctx) {
    return false;
  }

  /** Stub: no scorer is produced. */
  @Override
  public Scorer scorer(LeafReaderContext context) throws IOException {
    return null;
  }

  /** Stub: no explanation is produced. */
  @Override
  public Explanation explain(LeafReaderContext context, int doc) throws IOException {
    return null;
  }

  /** Stub: adds nothing to {@code terms}. */
  @Override
  public void extractTerms(Set<Term> terms) {
    // intentionally empty
  }
}
/**
 * Builds a {@link Scorer} backed by the given {@link DocIdSetIterator}, optionally exposing a
 * {@link TwoPhaseIterator} view through {@link Scorer#twoPhaseIterator()}. When that view is
 * not null, {@link DocIdSetIterator#nextDoc()}, {@link DocIdSetIterator#advance(int)} and
 * {@link DocIdSetIterator#cost()} on the scorer's iterator, as well as {@link Scorer#docID()},
 * throw {@link UnsupportedOperationException}, in order to make sure that
 * {@link ConjunctionDISI} takes advantage of the {@link TwoPhaseIterator} view.
 *
 * @param it the iterator every non-guarded call is forwarded to
 * @param twoPhaseIterator the two-phase view to expose, or {@code null} for none
 */
private static Scorer scorer(DocIdSetIterator it, TwoPhaseIterator twoPhaseIterator) {
  // Guard shared by all methods that must not be reached when a two-phase view exists.
  final boolean mustUseTwoPhase = twoPhaseIterator != null;
  final String misuse = "ConjunctionDISI should call the two-phase iterator";
  return new Scorer(new FakeWeight()) {
    @Override
    public TwoPhaseIterator twoPhaseIterator() {
      return twoPhaseIterator;
    }

    @Override
    public DocIdSetIterator iterator() {
      // A fresh wrapper per call; all wrappers forward to the same underlying iterator.
      return new DocIdSetIterator() {
        @Override
        public int docID() {
          // Reading the position is always allowed on the iterator view.
          return it.docID();
        }

        @Override
        public int nextDoc() throws IOException {
          if (mustUseTwoPhase) {
            throw new UnsupportedOperationException(misuse);
          }
          return it.nextDoc();
        }

        @Override
        public int advance(int target) throws IOException {
          if (mustUseTwoPhase) {
            throw new UnsupportedOperationException(misuse);
          }
          return it.advance(target);
        }

        @Override
        public long cost() {
          if (mustUseTwoPhase) {
            throw new UnsupportedOperationException(misuse);
          }
          return it.cost();
        }
      };
    }

    @Override
    public int docID() {
      if (mustUseTwoPhase) {
        throw new UnsupportedOperationException(misuse);
      }
      return it.docID();
    }

    @Override
    public float getMaxScore(int upTo) throws IOException {
      return 0;
    }

    @Override
    public float score() throws IOException {
      return 0;
    }
  };
}
/**
 * Creates a bit set of size {@code maxDoc} with randomly spaced set bits. A maximum step is
 * drawn with {@code TestUtil.nextInt(random(), 1, 10)}; the first bit is at
 * {@code random().nextInt(step)} and each following gap is drawn with
 * {@code TestUtil.nextInt(random(), 1, step)}.
 */
private static FixedBitSet randomSet(int maxDoc) {
  final int maxGap = TestUtil.nextInt(random(), 1, 10);
  final FixedBitSet bits = new FixedBitSet(maxDoc);
  int doc = random().nextInt(maxGap);
  while (doc < maxDoc) {
    bits.set(doc);
    doc += TestUtil.nextInt(random(), 1, maxGap);
  }
  return bits;
}
/**
 * Returns a new bit set of the same length as {@code other}, initialised with the bits of
 * {@code other}, in which every index is cleared when a coin flip comes up true.
 * {@code other} itself is not modified.
 */
private static FixedBitSet clearRandomBits(FixedBitSet other) {
  final int numBits = other.length();
  final FixedBitSet survivors = new FixedBitSet(numBits);
  survivors.or(other);
  int bit = 0;
  while (bit < numBits) {
    // One coin flip per index, whether or not the bit was set.
    final boolean drop = random().nextBoolean();
    if (drop) {
      survivors.clear(bit);
    }
    ++bit;
  }
  return survivors;
}
/**
 * Computes the intersection of all given bit sets into a new bit set sized like the first
 * one. The input sets are left untouched.
 */
private static FixedBitSet intersect(FixedBitSet[] bitSets) {
  final FixedBitSet first = bitSets[0];
  final FixedBitSet result = new FixedBitSet(first.length());
  // Seed with the first set, then narrow down with every remaining one.
  result.or(first);
  for (FixedBitSet next : Arrays.asList(bitSets).subList(1, bitSets.length)) {
    result.and(next);
  }
  return result;
}
/**
 * Drains {@code iterator} with {@code nextDoc()} until {@link DocIdSetIterator#NO_MORE_DOCS}
 * and records every returned doc ID in a new bit set of size {@code maxDoc}.
 */
private static FixedBitSet toBitSet(int maxDoc, DocIdSetIterator iterator) throws IOException {
  final FixedBitSet collected = new FixedBitSet(maxDoc);
  int doc;
  while ((doc = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
    collected.set(doc);
  }
  return collected;
}
/**
 * Test that the conjunction iterator is correct: for random mixes of plain iterators,
 * bit-set iterators and two-phase scorers, the documents returned by
 * {@code ConjunctionDISI.intersectScorers} must equal the intersection of the expected sets.
 */
public void testConjunction() throws IOException {
  final int rounds = atLeast(100);
  for (int round = 0; round < rounds; ++round) {
    final int maxDoc = TestUtil.nextInt(random(), 100, 10000);
    final int numClauses = TestUtil.nextInt(random(), 2, 5);
    final FixedBitSet[] expected = new FixedBitSet[numClauses];
    final Scorer[] clauses = new Scorer[numClauses];
    for (int i = 0; i < clauses.length; ++i) {
      final FixedBitSet docs = randomSet(maxDoc);
      final int flavor = random().nextInt(3);
      if (flavor == 0) {
        // simple iterator
        expected[i] = docs;
        final DocIdSetIterator hidden = anonymizeIterator(new BitDocIdSet(docs).iterator());
        clauses[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, hidden);
      } else if (flavor == 1) {
        // bitSet iterator
        expected[i] = docs;
        clauses[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(docs).iterator());
      } else {
        // scorer with approximation: only the confirmed subset is expected to match
        final FixedBitSet confirmed = clearRandomBits(docs);
        expected[i] = confirmed;
        clauses[i] = scorer(approximation(new BitDocIdSet(docs).iterator(), confirmed));
      }
    }
    final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(Arrays.asList(clauses));
    assertEquals(intersect(expected), toBitSet(maxDoc, conjunction));
  }
}
/**
 * Test that the conjunction approximation is correct: the conjunction must expose a
 * two-phase view exactly when at least one clause has one, and iterating that view must
 * yield the intersection of the expected sets.
 */
public void testConjunctionApproximation() throws IOException {
  final int rounds = atLeast(100);
  for (int round = 0; round < rounds; ++round) {
    final int maxDoc = TestUtil.nextInt(random(), 100, 10000);
    final int numClauses = TestUtil.nextInt(random(), 2, 5);
    final FixedBitSet[] expected = new FixedBitSet[numClauses];
    final Scorer[] clauses = new Scorer[numClauses];
    boolean anyTwoPhase = false;
    for (int i = 0; i < clauses.length; ++i) {
      final FixedBitSet docs = randomSet(maxDoc);
      final boolean plain = random().nextBoolean();
      if (!plain) {
        // scorer with approximation
        final FixedBitSet confirmed = clearRandomBits(docs);
        expected[i] = confirmed;
        clauses[i] = scorer(approximation(new BitDocIdSet(docs).iterator(), confirmed));
        anyTwoPhase = true;
      } else {
        // simple iterator
        expected[i] = docs;
        clauses[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.COMPLETE_NO_SCORES, new BitDocIdSet(docs).iterator());
      }
    }
    final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(Arrays.asList(clauses));
    final TwoPhaseIterator unwrapped = TwoPhaseIterator.unwrap(conjunction);
    assertEquals(anyTwoPhase, unwrapped != null);
    if (!anyTwoPhase) {
      // Nothing further to check when no clause carried an approximation.
      continue;
    }
    assertEquals(intersect(expected), toBitSet(maxDoc, TwoPhaseIterator.asDocIdSetIterator(unwrapped)));
  }
}
/**
 * This test makes sure that when nesting scorers with ConjunctionDISI, confirmations are
 * pushed to the root: clauses are folded one at a time into a growing conjunction that is
 * re-wrapped as a scorer, and the final scorer must expose a two-phase view exactly when
 * some clause had one.
 */
public void testRecursiveConjunctionApproximation() throws IOException {
  final int rounds = atLeast(100);
  for (int round = 0; round < rounds; ++round) {
    final int maxDoc = TestUtil.nextInt(random(), 100, 10000);
    final int numClauses = TestUtil.nextInt(random(), 2, 5);
    final FixedBitSet[] expected = new FixedBitSet[numClauses];
    Scorer nested = null;
    boolean anyTwoPhase = false;
    for (int i = 0; i < numClauses; ++i) {
      final FixedBitSet docs = randomSet(maxDoc);
      final Scorer clause;
      final int flavor = random().nextInt(3);
      if (flavor == 0) {
        // simple iterator
        expected[i] = docs;
        clause = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, anonymizeIterator(new BitDocIdSet(docs).iterator()));
      } else if (flavor == 1) {
        // bitSet iterator
        expected[i] = docs;
        clause = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(docs).iterator());
      } else {
        // scorer with approximation
        final FixedBitSet confirmed = clearRandomBits(docs);
        expected[i] = confirmed;
        clause = scorer(approximation(new BitDocIdSet(docs).iterator(), confirmed));
        anyTwoPhase = true;
      }

      if (nested != null) {
        // Fold the new clause into the conjunction built so far and wrap it as a scorer again.
        final DocIdSetIterator folded = ConjunctionDISI.intersectScorers(Arrays.asList(nested, clause));
        nested = scorer(folded, TwoPhaseIterator.unwrap(folded));
      } else {
        nested = clause;
      }
    }

    final TwoPhaseIterator rootTwoPhase = nested.twoPhaseIterator();
    assertEquals(anyTwoPhase, rootTwoPhase != null);
    // Iterate through the two-phase view when there is one, through the plain iterator otherwise.
    final DocIdSetIterator actual = anyTwoPhase
        ? TwoPhaseIterator.asDocIdSetIterator(rootTwoPhase)
        : nested.iterator();
    assertEquals(intersect(expected), toBitSet(maxDoc, actual));
  }
}
/**
 * Shared body of the sub-conjunction tests: builds 5 to 10 random clauses, repeatedly
 * replaces a random run of at least two adjacent scorers by a single scorer over their
 * conjunction, and finally checks that the conjunction of what is left matches the
 * intersection of the expected sets.
 *
 * @param wrapWithScorer if true each sub-conjunction is a {@code ConjunctionScorer};
 *     otherwise it is a {@code ConstantScoreScorer} around
 *     {@code ConjunctionDISI.intersectScorers} of the run
 */
public void testCollapseSubConjunctions(boolean wrapWithScorer) throws IOException {
final int iters = atLeast(100);
for (int iter = 0; iter < iters; ++iter) {
final int maxDoc = TestUtil.nextInt(random(), 100, 10000);
final int numIterators = TestUtil.nextInt(random(), 5, 10);
// sets[i] holds the documents clause i is expected to match
final FixedBitSet[] sets = new FixedBitSet[numIterators];
final List<Scorer> scorers = new LinkedList<>();
for (int i = 0; i < numIterators; ++i) {
final FixedBitSet set = randomSet(maxDoc);
if (random().nextBoolean()) {
// simple iterator
sets[i] = set;
scorers.add(new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(set).iterator()));
} else {
// scorer with approximation: only the confirmed subset is expected to match
final FixedBitSet confirmed = clearRandomBits(set);
sets[i] = confirmed;
final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed);
scorers.add(scorer(approximation));
}
}
// make some sub sequences into sub conjunctions
final int subIters = atLeast(3);
// stop collapsing once three or fewer scorers remain
for (int subIter = 0; subIter < subIters && scorers.size() > 3; ++subIter) {
// pick a run [subSeqStart, subSeqEnd) that covers at least two scorers
final int subSeqStart = TestUtil.nextInt(random(), 0, scorers.size() - 2);
final int subSeqEnd = TestUtil.nextInt(random(), subSeqStart + 2, scorers.size());
// NOTE(review): subIterators is a view of scorers, which is structurally modified just
// below; this is only safe if the sub-conjunction never reads the view afterwards --
// confirm against ConjunctionScorer / ConjunctionDISI.
List<Scorer> subIterators = scorers.subList(subSeqStart, subSeqEnd);
Scorer subConjunction;
if (wrapWithScorer) {
subConjunction = new ConjunctionScorer(new FakeWeight(), subIterators, Collections.emptyList());
} else {
subConjunction = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, ConjunctionDISI.intersectScorers(subIterators));
}
// put the sub-conjunction in the first slot of the run, then drop the rest of the run
scorers.set(subSeqStart, subConjunction);
int toRemove = subSeqEnd - subSeqStart - 1;
while (toRemove-- > 0) {
scorers.remove(subSeqStart + 1);
}
}
if (scorers.size() == 1) {
// ConjunctionDISI needs two iterators
// presumably DocIdSetIterator.all(maxDoc) matches every doc, leaving the expected result unchanged -- verify
scorers.add(new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, DocIdSetIterator.all(maxDoc)));
}
final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(scorers);
assertEquals(intersect(sets), toBitSet(maxDoc, conjunction));
}
}
/** Runs the sub-conjunction collapsing test without wrapping sub-conjunctions in a ConjunctionScorer. */
public void testCollapseSubConjunctionDISIs() throws IOException {
  final boolean wrapWithScorer = false;
  testCollapseSubConjunctions(wrapWithScorer);
}
/** Runs the sub-conjunction collapsing test with sub-conjunctions wrapped in a ConjunctionScorer. */
public void testCollapseSubConjunctionScorers() throws IOException {
  final boolean wrapWithScorer = true;
  testCollapseSubConjunctions(wrapWithScorer);
}
/**
 * Checks that moving one sub-iterator directly, behind the conjunction's back, makes the next
 * call to the conjunction fail with the expected {@link AssertionError}. Skipped when
 * assertions are disabled.
 */
public void testIllegalAdvancementOfSubIteratorsTripsAssertion() throws IOException {
  assumeTrue("Assertions must be enabled for this test!", LuceneTestCase.assertsAreEnabled);
  final int maxDoc = 100;
  final int numIterators = TestUtil.nextInt(random(), 2, 5);
  // Every sub-iterator walks the same set of documents.
  final FixedBitSet sharedDocs = randomSet(maxDoc);
  final DocIdSetIterator[] subs = new DocIdSetIterator[numIterators];
  int slot = 0;
  while (slot < subs.length) {
    subs[slot] = new BitDocIdSet(sharedDocs).iterator();
    ++slot;
  }
  final DocIdSetIterator conjunction = ConjunctionDISI.intersectIterators(Arrays.asList(subs));
  final int victim = TestUtil.nextInt(random(), 0, subs.length - 1);
  subs[victim].nextDoc(); // illegally advancing one of the sub-iterators outside of the conjunction iterator
  final AssertionError error = expectThrows(AssertionError.class, () -> {
    conjunction.nextDoc();
  });
  assertEquals("Sub-iterators of ConjunctionDISI are not on the same document!", error.getMessage());
}
}