| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search; |
| |
| |
| import java.io.IOException; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.Set; |
| |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.Term; |
| import org.apache.lucene.util.BitDocIdSet; |
| import org.apache.lucene.util.FixedBitSet; |
| import org.apache.lucene.util.LuceneTestCase; |
| import org.apache.lucene.util.TestUtil; |
| |
| public class TestConjunctionDISI extends LuceneTestCase { |
| |
| private static TwoPhaseIterator approximation(DocIdSetIterator iterator, final FixedBitSet confirmed) { |
| DocIdSetIterator approximation; |
| if (random().nextBoolean()) { |
| approximation = anonymizeIterator(iterator); |
| } else { |
| approximation = iterator; |
| } |
| return new TwoPhaseIterator(approximation) { |
| |
| @Override |
| public boolean matches() { |
| return confirmed.get(approximation.docID()); |
| } |
| |
| @Override |
| public float matchCost() { |
| return 5; // #operations in FixedBitSet#get() |
| } |
| }; |
| } |
| |
| /** Return an anonym class so that ConjunctionDISI cannot optimize it |
| * like it does eg. for BitSetIterators. */ |
| private static DocIdSetIterator anonymizeIterator(DocIdSetIterator it) { |
| return new DocIdSetIterator() { |
| |
| @Override |
| public int nextDoc() throws IOException { |
| return it.nextDoc(); |
| } |
| |
| @Override |
| public int docID() { |
| return it.docID(); |
| } |
| |
| @Override |
| public long cost() { |
| return it.docID(); |
| } |
| |
| @Override |
| public int advance(int target) throws IOException { |
| return it.advance(target); |
| } |
| }; |
| } |
| |
| private static Scorer scorer(TwoPhaseIterator twoPhaseIterator) { |
| return scorer(TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator), twoPhaseIterator); |
| } |
| |
| private static class FakeWeight extends Weight { |
| |
| protected FakeWeight() { |
| super(new MatchNoDocsQuery()); |
| } |
| |
| @Override |
| public void extractTerms(Set<Term> terms) { |
| |
| } |
| |
| @Override |
| public Explanation explain(LeafReaderContext context, int doc) throws IOException { |
| return null; |
| } |
| |
| @Override |
| public Scorer scorer(LeafReaderContext context) throws IOException { |
| return null; |
| } |
| |
| @Override |
| public boolean isCacheable(LeafReaderContext ctx) { |
| return false; |
| } |
| } |
| |
| /** |
| * Create a {@link Scorer} that wraps the given {@link DocIdSetIterator}. It |
| * also accepts a {@link TwoPhaseIterator} view, which is exposed in |
| * {@link Scorer#twoPhaseIterator()}. When the two-phase view is not null, |
| * then {@link DocIdSetIterator#nextDoc()} and {@link DocIdSetIterator#advance(int)} will raise |
| * an exception in order to make sure that {@link ConjunctionDISI} takes |
| * advantage of the {@link TwoPhaseIterator} view. |
| */ |
| private static Scorer scorer(DocIdSetIterator it, TwoPhaseIterator twoPhaseIterator) { |
| return new Scorer(new FakeWeight()) { |
| |
| @Override |
| public DocIdSetIterator iterator() { |
| return new DocIdSetIterator() { |
| |
| @Override |
| public int docID() { |
| return it.docID(); |
| } |
| |
| @Override |
| public int nextDoc() throws IOException { |
| if (twoPhaseIterator != null) { |
| throw new UnsupportedOperationException("ConjunctionDISI should call the two-phase iterator"); |
| } |
| return it.nextDoc(); |
| } |
| |
| @Override |
| public int advance(int target) throws IOException { |
| if (twoPhaseIterator != null) { |
| throw new UnsupportedOperationException("ConjunctionDISI should call the two-phase iterator"); |
| } |
| return it.advance(target); |
| } |
| |
| @Override |
| public long cost() { |
| if (twoPhaseIterator != null) { |
| throw new UnsupportedOperationException("ConjunctionDISI should call the two-phase iterator"); |
| } |
| return it.cost(); |
| } |
| }; |
| } |
| |
| @Override |
| public TwoPhaseIterator twoPhaseIterator() { |
| return twoPhaseIterator; |
| } |
| |
| @Override |
| public int docID() { |
| if (twoPhaseIterator != null) { |
| throw new UnsupportedOperationException("ConjunctionDISI should call the two-phase iterator"); |
| } |
| return it.docID(); |
| } |
| |
| @Override |
| public float score() throws IOException { |
| return 0; |
| } |
| |
| @Override |
| public float getMaxScore(int upTo) throws IOException { |
| return 0; |
| } |
| }; |
| } |
| |
| private static FixedBitSet randomSet(int maxDoc) { |
| final int step = TestUtil.nextInt(random(), 1, 10); |
| FixedBitSet set = new FixedBitSet(maxDoc); |
| for (int doc = random().nextInt(step); doc < maxDoc; doc += TestUtil.nextInt(random(), 1, step)) { |
| set.set(doc); |
| } |
| return set; |
| } |
| |
| private static FixedBitSet clearRandomBits(FixedBitSet other) { |
| final FixedBitSet set = new FixedBitSet(other.length()); |
| set.or(other); |
| for (int i = 0; i < set.length(); ++i) { |
| if (random().nextBoolean()) { |
| set.clear(i); |
| } |
| } |
| return set; |
| } |
| |
| private static FixedBitSet intersect(FixedBitSet[] bitSets) { |
| final FixedBitSet intersection = new FixedBitSet(bitSets[0].length()); |
| intersection.or(bitSets[0]); |
| for (int i = 1; i < bitSets.length; ++i) { |
| intersection.and(bitSets[i]); |
| } |
| return intersection; |
| } |
| |
| private static FixedBitSet toBitSet(int maxDoc, DocIdSetIterator iterator) throws IOException { |
| final FixedBitSet set = new FixedBitSet(maxDoc); |
| for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { |
| set.set(doc); |
| } |
| return set; |
| } |
| |
| // Test that the conjunction iterator is correct |
| public void testConjunction() throws IOException { |
| final int iters = atLeast(100); |
| for (int iter = 0; iter < iters; ++iter) { |
| final int maxDoc = TestUtil.nextInt(random(), 100, 10000); |
| final int numIterators = TestUtil.nextInt(random(), 2, 5); |
| final FixedBitSet[] sets = new FixedBitSet[numIterators]; |
| final Scorer[] iterators = new Scorer[numIterators]; |
| for (int i = 0; i < iterators.length; ++i) { |
| final FixedBitSet set = randomSet(maxDoc); |
| switch (random().nextInt(3)) { |
| case 0: |
| // simple iterator |
| sets[i] = set; |
| iterators[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, anonymizeIterator(new BitDocIdSet(set).iterator())); |
| break; |
| case 1: |
| // bitSet iterator |
| sets[i] = set; |
| iterators[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(set).iterator()); |
| break; |
| default: |
| // scorer with approximation |
| final FixedBitSet confirmed = clearRandomBits(set); |
| sets[i] = confirmed; |
| final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); |
| iterators[i] = scorer(approximation); |
| break; |
| } |
| } |
| |
| final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(Arrays.asList(iterators)); |
| assertEquals(intersect(sets), toBitSet(maxDoc, conjunction)); |
| } |
| } |
| |
| // Test that the conjunction approximation is correct |
| public void testConjunctionApproximation() throws IOException { |
| final int iters = atLeast(100); |
| for (int iter = 0; iter < iters; ++iter) { |
| final int maxDoc = TestUtil.nextInt(random(), 100, 10000); |
| final int numIterators = TestUtil.nextInt(random(), 2, 5); |
| final FixedBitSet[] sets = new FixedBitSet[numIterators]; |
| final Scorer[] iterators = new Scorer[numIterators]; |
| boolean hasApproximation = false; |
| for (int i = 0; i < iterators.length; ++i) { |
| final FixedBitSet set = randomSet(maxDoc); |
| if (random().nextBoolean()) { |
| // simple iterator |
| sets[i] = set; |
| iterators[i] = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.COMPLETE_NO_SCORES, new BitDocIdSet(set).iterator()); |
| } else { |
| // scorer with approximation |
| final FixedBitSet confirmed = clearRandomBits(set); |
| sets[i] = confirmed; |
| final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); |
| iterators[i] = scorer(approximation); |
| hasApproximation = true; |
| } |
| } |
| |
| final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(Arrays.asList(iterators)); |
| TwoPhaseIterator twoPhaseIterator = TwoPhaseIterator.unwrap(conjunction); |
| assertEquals(hasApproximation, twoPhaseIterator != null); |
| if (hasApproximation) { |
| assertEquals(intersect(sets), toBitSet(maxDoc, TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator))); |
| } |
| } |
| } |
| |
| // This test makes sure that when nesting scorers with ConjunctionDISI, confirmations are pushed to the root. |
| public void testRecursiveConjunctionApproximation() throws IOException { |
| final int iters = atLeast(100); |
| for (int iter = 0; iter < iters; ++iter) { |
| final int maxDoc = TestUtil.nextInt(random(), 100, 10000); |
| final int numIterators = TestUtil.nextInt(random(), 2, 5); |
| final FixedBitSet[] sets = new FixedBitSet[numIterators]; |
| Scorer conjunction = null; |
| boolean hasApproximation = false; |
| for (int i = 0; i < numIterators; ++i) { |
| final FixedBitSet set = randomSet(maxDoc); |
| final Scorer newIterator; |
| switch (random().nextInt(3)) { |
| case 0: |
| // simple iterator |
| sets[i] = set; |
| newIterator = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, anonymizeIterator(new BitDocIdSet(set).iterator())); |
| break; |
| case 1: |
| // bitSet iterator |
| sets[i] = set; |
| newIterator = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(set).iterator()); |
| break; |
| default: |
| // scorer with approximation |
| final FixedBitSet confirmed = clearRandomBits(set); |
| sets[i] = confirmed; |
| final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); |
| newIterator = scorer(approximation); |
| hasApproximation = true; |
| break; |
| } |
| if (conjunction == null) { |
| conjunction = newIterator; |
| } else { |
| final DocIdSetIterator conj = ConjunctionDISI.intersectScorers(Arrays.asList(conjunction, newIterator)); |
| conjunction = scorer(conj, TwoPhaseIterator.unwrap(conj)); |
| } |
| } |
| |
| TwoPhaseIterator twoPhaseIterator = conjunction.twoPhaseIterator(); |
| assertEquals(hasApproximation, twoPhaseIterator != null); |
| if (hasApproximation) { |
| assertEquals(intersect(sets), toBitSet(maxDoc, TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator))); |
| } else { |
| assertEquals(intersect(sets), toBitSet(maxDoc, conjunction.iterator())); |
| } |
| } |
| } |
| |
| public void testCollapseSubConjunctions(boolean wrapWithScorer) throws IOException { |
| final int iters = atLeast(100); |
| for (int iter = 0; iter < iters; ++iter) { |
| final int maxDoc = TestUtil.nextInt(random(), 100, 10000); |
| final int numIterators = TestUtil.nextInt(random(), 5, 10); |
| final FixedBitSet[] sets = new FixedBitSet[numIterators]; |
| final List<Scorer> scorers = new LinkedList<>(); |
| for (int i = 0; i < numIterators; ++i) { |
| final FixedBitSet set = randomSet(maxDoc); |
| if (random().nextBoolean()) { |
| // simple iterator |
| sets[i] = set; |
| scorers.add(new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, new BitDocIdSet(set).iterator())); |
| } else { |
| // scorer with approximation |
| final FixedBitSet confirmed = clearRandomBits(set); |
| sets[i] = confirmed; |
| final TwoPhaseIterator approximation = approximation(new BitDocIdSet(set).iterator(), confirmed); |
| scorers.add(scorer(approximation)); |
| } |
| } |
| |
| // make some sub sequences into sub conjunctions |
| final int subIters = atLeast(3); |
| for (int subIter = 0; subIter < subIters && scorers.size() > 3; ++subIter) { |
| final int subSeqStart = TestUtil.nextInt(random(), 0, scorers.size() - 2); |
| final int subSeqEnd = TestUtil.nextInt(random(), subSeqStart + 2, scorers.size()); |
| List<Scorer> subIterators = scorers.subList(subSeqStart, subSeqEnd); |
| Scorer subConjunction; |
| if (wrapWithScorer) { |
| subConjunction = new ConjunctionScorer(new FakeWeight(), subIterators, Collections.emptyList()); |
| } else { |
| subConjunction = new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, ConjunctionDISI.intersectScorers(subIterators)); |
| } |
| scorers.set(subSeqStart, subConjunction); |
| int toRemove = subSeqEnd - subSeqStart - 1; |
| while (toRemove-- > 0) { |
| scorers.remove(subSeqStart + 1); |
| } |
| } |
| if (scorers.size() == 1) { |
| // ConjunctionDISI needs two iterators |
| scorers.add(new ConstantScoreScorer(new FakeWeight(), 0f, ScoreMode.TOP_SCORES, DocIdSetIterator.all(maxDoc))); |
| } |
| |
| |
| final DocIdSetIterator conjunction = ConjunctionDISI.intersectScorers(scorers); |
| assertEquals(intersect(sets), toBitSet(maxDoc, conjunction)); |
| } |
| } |
| |
| public void testCollapseSubConjunctionDISIs() throws IOException { |
| testCollapseSubConjunctions(false); |
| } |
| |
| public void testCollapseSubConjunctionScorers() throws IOException { |
| testCollapseSubConjunctions(true); |
| } |
| |
| public void testIllegalAdvancementOfSubIteratorsTripsAssertion() throws IOException { |
| assumeTrue("Assertions must be enabled for this test!", LuceneTestCase.assertsAreEnabled); |
| int maxDoc = 100; |
| final int numIterators = TestUtil.nextInt(random(), 2, 5); |
| FixedBitSet set = randomSet(maxDoc); |
| |
| DocIdSetIterator[] iterators = new DocIdSetIterator[numIterators]; |
| for (int i = 0; i < iterators.length; ++i) { |
| iterators[i] = new BitDocIdSet(set).iterator(); |
| } |
| final DocIdSetIterator conjunction = ConjunctionDISI.intersectIterators(Arrays.asList(iterators)); |
| int idx = TestUtil.nextInt(random() , 0, iterators.length-1); |
| iterators[idx].nextDoc(); // illegally advancing one of the sub-iterators outside of the conjunction iterator |
| AssertionError ex = expectThrows(AssertionError.class, () -> conjunction.nextDoc()); |
| assertEquals("Sub-iterators of ConjunctionDISI are not on the same document!", ex.getMessage()); |
| } |
| } |