/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| package org.apache.lucene.search; |
| |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.List; |
| |
| import org.apache.lucene.search.spans.Spans; |
| import org.apache.lucene.util.ArrayUtil; |
| import org.apache.lucene.util.BitSet; |
| import org.apache.lucene.util.BitSetIterator; |
| import org.apache.lucene.util.CollectionUtil; |
| |
| /** A conjunction of DocIdSetIterators. |
| * Requires that all of its sub-iterators must be on the same document all the time. |
| * This iterates over the doc ids that are present in each given DocIdSetIterator. |
| * <br>Public only for use in {@link org.apache.lucene.search.spans}. |
| * @lucene.internal |
| */ |
| public final class ConjunctionDISI extends DocIdSetIterator { |
| |
| /** Create a conjunction over the provided {@link Scorer}s. Note that the |
| * returned {@link DocIdSetIterator} might leverage two-phase iteration in |
| * which case it is possible to retrieve the {@link TwoPhaseIterator} using |
| * {@link TwoPhaseIterator#unwrap}. */ |
| public static DocIdSetIterator intersectScorers(Collection<Scorer> scorers) { |
| if (scorers.size() < 2) { |
| throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); |
| } |
| final List<DocIdSetIterator> allIterators = new ArrayList<>(); |
| final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>(); |
| for (Scorer scorer : scorers) { |
| addScorer(scorer, allIterators, twoPhaseIterators); |
| } |
| |
| return createConjunction(allIterators, twoPhaseIterators); |
| } |
| |
| /** Create a conjunction over the provided DocIdSetIterators. Note that the |
| * returned {@link DocIdSetIterator} might leverage two-phase iteration in |
| * which case it is possible to retrieve the {@link TwoPhaseIterator} using |
| * {@link TwoPhaseIterator#unwrap}. */ |
| public static DocIdSetIterator intersectIterators(List<DocIdSetIterator> iterators) { |
| if (iterators.size() < 2) { |
| throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); |
| } |
| final List<DocIdSetIterator> allIterators = new ArrayList<>(); |
| final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>(); |
| for (DocIdSetIterator iterator : iterators) { |
| addIterator(iterator, allIterators, twoPhaseIterators); |
| } |
| |
| return createConjunction(allIterators, twoPhaseIterators); |
| } |
| |
| /** Create a conjunction over the provided {@link Spans}. Note that the |
| * returned {@link DocIdSetIterator} might leverage two-phase iteration in |
| * which case it is possible to retrieve the {@link TwoPhaseIterator} using |
| * {@link TwoPhaseIterator#unwrap}. */ |
| public static DocIdSetIterator intersectSpans(List<Spans> spanList) { |
| if (spanList.size() < 2) { |
| throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); |
| } |
| final List<DocIdSetIterator> allIterators = new ArrayList<>(); |
| final List<TwoPhaseIterator> twoPhaseIterators = new ArrayList<>(); |
| for (Spans spans : spanList) { |
| addSpans(spans, allIterators, twoPhaseIterators); |
| } |
| |
| return createConjunction(allIterators, twoPhaseIterators); |
| } |
| |
| /** Adds the scorer, possibly splitting up into two phases or collapsing if it is another conjunction */ |
| private static void addScorer(Scorer scorer, List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) { |
| TwoPhaseIterator twoPhaseIter = scorer.twoPhaseIterator(); |
| if (twoPhaseIter != null) { |
| addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators); |
| } else { // no approximation support, use the iterator as-is |
| addIterator(scorer.iterator(), allIterators, twoPhaseIterators); |
| } |
| } |
| |
| /** Adds the Spans. */ |
| private static void addSpans(Spans spans, List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) { |
| TwoPhaseIterator twoPhaseIter = spans.asTwoPhaseIterator(); |
| if (twoPhaseIter != null) { |
| addTwoPhaseIterator(twoPhaseIter, allIterators, twoPhaseIterators); |
| } else { // no approximation support, use the iterator as-is |
| addIterator(spans, allIterators, twoPhaseIterators); |
| } |
| } |
| |
| private static void addIterator(DocIdSetIterator disi, List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) { |
| TwoPhaseIterator twoPhase = TwoPhaseIterator.unwrap(disi); |
| if (twoPhase != null) { |
| addTwoPhaseIterator(twoPhase, allIterators, twoPhaseIterators); |
| } else if (disi.getClass() == ConjunctionDISI.class) { // Check for exactly this class for collapsing |
| ConjunctionDISI conjunction = (ConjunctionDISI) disi; |
| // subconjuctions have already split themselves into two phase iterators and others, so we can take those |
| // iterators as they are and move them up to this conjunction |
| allIterators.add(conjunction.lead1); |
| allIterators.add(conjunction.lead2); |
| Collections.addAll(allIterators, conjunction.others); |
| } else if (disi.getClass() == BitSetConjunctionDISI.class) { |
| BitSetConjunctionDISI conjunction = (BitSetConjunctionDISI) disi; |
| allIterators.add(conjunction.lead); |
| Collections.addAll(allIterators, conjunction.bitSetIterators); |
| } else { |
| allIterators.add(disi); |
| } |
| } |
| |
| private static void addTwoPhaseIterator(TwoPhaseIterator twoPhaseIter, List<DocIdSetIterator> allIterators, List<TwoPhaseIterator> twoPhaseIterators) { |
| addIterator(twoPhaseIter.approximation(), allIterators, twoPhaseIterators); |
| if (twoPhaseIter.getClass() == ConjunctionTwoPhaseIterator.class) { // Check for exactly this class for collapsing |
| Collections.addAll(twoPhaseIterators, ((ConjunctionTwoPhaseIterator) twoPhaseIter).twoPhaseIterators); |
| } else { |
| twoPhaseIterators.add(twoPhaseIter); |
| } |
| } |
| |
| private static DocIdSetIterator createConjunction( |
| List<DocIdSetIterator> allIterators, |
| List<TwoPhaseIterator> twoPhaseIterators) { |
| |
| // check that all sub-iterators are on the same doc ID |
| int curDoc = allIterators.size() > 0 ? allIterators.get(0).docID() : twoPhaseIterators.get(0).approximation.docID(); |
| boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc); |
| iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc); |
| if (iteratorsOnTheSameDoc == false) { |
| throw new IllegalArgumentException("Sub-iterators of ConjunctionDISI are not on the same document!"); |
| } |
| |
| long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong(); |
| List<BitSetIterator> bitSetIterators = new ArrayList<>(); |
| List<DocIdSetIterator> iterators = new ArrayList<>(); |
| for (DocIdSetIterator iterator : allIterators) { |
| if (iterator.cost() > minCost && iterator instanceof BitSetIterator) { |
| // we put all bitset iterators into bitSetIterators |
| // except if they have the minimum cost, since we need |
| // them to lead the iteration in that case |
| bitSetIterators.add((BitSetIterator) iterator); |
| } else { |
| iterators.add(iterator); |
| } |
| } |
| |
| DocIdSetIterator disi; |
| if (iterators.size() == 1) { |
| disi = iterators.get(0); |
| } else { |
| disi = new ConjunctionDISI(iterators); |
| } |
| |
| if (bitSetIterators.size() > 0) { |
| disi = new BitSetConjunctionDISI(disi, bitSetIterators); |
| } |
| |
| if (twoPhaseIterators.isEmpty() == false) { |
| disi = TwoPhaseIterator.asDocIdSetIterator(new ConjunctionTwoPhaseIterator(disi, twoPhaseIterators)); |
| } |
| |
| return disi; |
| } |
| |
| final DocIdSetIterator lead1, lead2; |
| final DocIdSetIterator[] others; |
| |
| private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) { |
| assert iterators.size() >= 2; |
| |
| // Sort the array the first time to allow the least frequent DocsEnum to |
| // lead the matching. |
| CollectionUtil.timSort(iterators, new Comparator<DocIdSetIterator>() { |
| @Override |
| public int compare(DocIdSetIterator o1, DocIdSetIterator o2) { |
| return Long.compare(o1.cost(), o2.cost()); |
| } |
| }); |
| lead1 = iterators.get(0); |
| lead2 = iterators.get(1); |
| others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]); |
| } |
| |
| private int doNext(int doc) throws IOException { |
| advanceHead: for(;;) { |
| assert doc == lead1.docID(); |
| |
| // find agreement between the two iterators with the lower costs |
| // we special case them because they do not need the |
| // 'other.docID() < doc' check that the 'others' iterators need |
| final int next2 = lead2.advance(doc); |
| if (next2 != doc) { |
| doc = lead1.advance(next2); |
| if (next2 != doc) { |
| continue; |
| } |
| } |
| |
| // then find agreement with other iterators |
| for (DocIdSetIterator other : others) { |
| // other.doc may already be equal to doc if we "continued advanceHead" |
| // on the previous iteration and the advance on the lead scorer exactly matched. |
| if (other.docID() < doc) { |
| final int next = other.advance(doc); |
| |
| if (next > doc) { |
| // iterator beyond the current doc - advance lead and continue to the new highest doc. |
| doc = lead1.advance(next); |
| continue advanceHead; |
| } |
| } |
| } |
| |
| // success - all iterators are on the same doc |
| return doc; |
| } |
| } |
| |
| @Override |
| public int advance(int target) throws IOException { |
| assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not one the same document!"; |
| return doNext(lead1.advance(target)); |
| } |
| |
| @Override |
| public int docID() { |
| return lead1.docID(); |
| } |
| |
| @Override |
| public int nextDoc() throws IOException { |
| assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; |
| return doNext(lead1.nextDoc()); |
| } |
| |
| @Override |
| public long cost() { |
| return lead1.cost(); // overestimate |
| } |
| |
| // Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise |
| private boolean assertItersOnSameDoc() { |
| int curDoc = lead1.docID(); |
| boolean iteratorsOnTheSameDoc = (lead2.docID() == curDoc); |
| for (int i = 0; (i < others.length && iteratorsOnTheSameDoc); i++) { |
| iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (others[i].docID() == curDoc); |
| } |
| return iteratorsOnTheSameDoc; |
| } |
| |
| /** Conjunction between a {@link DocIdSetIterator} and one or more {@link BitSetIterator}s. */ |
| private static class BitSetConjunctionDISI extends DocIdSetIterator { |
| |
| private final DocIdSetIterator lead; |
| private final BitSetIterator[] bitSetIterators; |
| private final BitSet[] bitSets; |
| private final int minLength; |
| |
| BitSetConjunctionDISI(DocIdSetIterator lead, Collection<BitSetIterator> bitSetIterators) { |
| this.lead = lead; |
| assert bitSetIterators.size() > 0; |
| |
| this.bitSetIterators = bitSetIterators.toArray(new BitSetIterator[0]); |
| // Put the least costly iterators first so that we exit as soon as possible |
| ArrayUtil.timSort(this.bitSetIterators, (a, b) -> Long.compare(a.cost(), b.cost())); |
| this.bitSets = new BitSet[this.bitSetIterators.length]; |
| int minLen = Integer.MAX_VALUE; |
| for (int i = 0; i < this.bitSetIterators.length; ++i) { |
| BitSet bitSet = this.bitSetIterators[i].getBitSet(); |
| this.bitSets[i] = bitSet; |
| minLen = Math.min(minLen, bitSet.length()); |
| } |
| this.minLength = minLen; |
| } |
| |
| @Override |
| public int docID() { |
| return lead.docID(); |
| } |
| |
| @Override |
| public int nextDoc() throws IOException { |
| assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; |
| return doNext(lead.nextDoc()); |
| } |
| |
| @Override |
| public int advance(int target) throws IOException { |
| assert assertItersOnSameDoc() : "Sub-iterators of ConjunctionDISI are not on the same document!"; |
| return doNext(lead.advance(target)); |
| } |
| |
| private int doNext(int doc) throws IOException { |
| advanceLead: for (;; doc = lead.nextDoc()) { |
| if (doc >= minLength) { |
| return NO_MORE_DOCS; |
| } |
| for (BitSet bitSet : bitSets) { |
| if (bitSet.get(doc) == false) { |
| continue advanceLead; |
| } |
| } |
| for (BitSetIterator iterator : bitSetIterators) { |
| iterator.setDocId(doc); |
| } |
| return doc; |
| } |
| } |
| |
| @Override |
| public long cost() { |
| return lead.cost(); |
| } |
| |
| // Returns {@code true} if all sub-iterators are on the same doc ID, {@code false} otherwise |
| private boolean assertItersOnSameDoc() { |
| int curDoc = lead.docID(); |
| boolean iteratorsOnTheSameDoc = true; |
| for (int i = 0; (i < bitSetIterators.length && iteratorsOnTheSameDoc); i++) { |
| iteratorsOnTheSameDoc = iteratorsOnTheSameDoc && (bitSetIterators[i].docID() == curDoc); |
| } |
| return iteratorsOnTheSameDoc; |
| } |
| |
| } |
| |
| /** |
| * {@link TwoPhaseIterator} implementing a conjunction. |
| */ |
| private static final class ConjunctionTwoPhaseIterator extends TwoPhaseIterator { |
| |
| private final TwoPhaseIterator[] twoPhaseIterators; |
| private final float matchCost; |
| |
| private ConjunctionTwoPhaseIterator(DocIdSetIterator approximation, |
| List<? extends TwoPhaseIterator> twoPhaseIterators) { |
| super(approximation); |
| assert twoPhaseIterators.size() > 0; |
| |
| CollectionUtil.timSort(twoPhaseIterators, new Comparator<TwoPhaseIterator>() { |
| @Override |
| public int compare(TwoPhaseIterator o1, TwoPhaseIterator o2) { |
| return Float.compare(o1.matchCost(), o2.matchCost()); |
| } |
| }); |
| |
| this.twoPhaseIterators = twoPhaseIterators.toArray(new TwoPhaseIterator[twoPhaseIterators.size()]); |
| |
| // Compute the matchCost as the total matchCost of the sub iterators. |
| // TODO: This could be too high because the matching is done cheapest first: give the lower matchCosts a higher weight. |
| float totalMatchCost = 0; |
| for (TwoPhaseIterator tpi : twoPhaseIterators) { |
| totalMatchCost += tpi.matchCost(); |
| } |
| matchCost = totalMatchCost; |
| } |
| |
| @Override |
| public boolean matches() throws IOException { |
| for (TwoPhaseIterator twoPhaseIterator : twoPhaseIterators) { // match cheapest first |
| if (twoPhaseIterator.matches() == false) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| @Override |
| public float matchCost() { |
| return matchCost; |
| } |
| |
| } |
| |
| } |