| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search; |
| |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.List; |
| import java.util.stream.LongStream; |
| import java.util.stream.StreamSupport; |
| |
| import org.apache.lucene.util.PriorityQueue; |
| |
| import static org.apache.lucene.search.DisiPriorityQueue.leftNode; |
| import static org.apache.lucene.search.DisiPriorityQueue.parentNode; |
| import static org.apache.lucene.search.DisiPriorityQueue.rightNode; |
| |
| /** |
| * A {@link Scorer} for {@link BooleanQuery} when |
| * {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int) minShouldMatch} is |
| * between 2 and the total number of clauses. |
| * |
| * This implementation keeps sub scorers in 3 different places: |
| * - lead: a linked list of scorer that are positioned on the desired doc ID |
| * - tail: a heap that contains at most minShouldMatch - 1 scorers that are |
| * behind the desired doc ID. These scorers are ordered by cost so that we |
| * can advance the least costly ones first. |
| * - head: a heap that contains scorers which are beyond the desired doc ID, |
| * ordered by doc ID in order to move quickly to the next candidate. |
| * |
| * Finding the next match consists of first setting the desired doc ID to the |
| * least entry in 'head' and then advance 'tail' until there is a match. |
| */ |
| final class MinShouldMatchSumScorer extends Scorer { |
| |
| static long cost(LongStream costs, int numScorers, int minShouldMatch) { |
| // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m |
| // could be rewritten to: |
| // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m)) |
| // if we assume that clauses come in ascending cost, then |
| // the cost of the first part is the cost of c1 (because the cost of a conjunction is |
| // the cost of the least costly clause) |
| // the cost of the second part is the cost of finding m matches among the c2...cn |
| // remaining clauses |
| // since it is a disjunction overall, the total cost is the sum of the costs of these |
| // two parts |
| |
| // If we recurse infinitely, we find out that the cost of a msm query is the sum of the |
| // costs of the num_scorers - minShouldMatch + 1 least costly scorers |
| final PriorityQueue<Long> pq = new PriorityQueue<Long>(numScorers - minShouldMatch + 1) { |
| @Override |
| protected boolean lessThan(Long a, Long b) { |
| return a > b; |
| } |
| }; |
| costs.forEach(pq::insertWithOverflow); |
| return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum(); |
| } |
| |
| final int minShouldMatch; |
| |
| // list of scorers which 'lead' the iteration and are currently |
| // positioned on 'doc' |
| DisiWrapper lead; |
| int doc; // current doc ID of the leads |
| int freq; // number of scorers on the desired doc ID |
| |
| // priority queue of scorers that are too advanced compared to the current |
| // doc. Ordered by doc ID. |
| final DisiPriorityQueue head; |
| |
| // priority queue of scorers which are behind the current doc. |
| // Ordered by cost. |
| final DisiWrapper[] tail; |
| int tailSize; |
| |
| final long cost; |
| |
| MinShouldMatchSumScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch) { |
| super(weight); |
| |
| if (minShouldMatch > scorers.size()) { |
| throw new IllegalArgumentException("minShouldMatch should be <= the number of scorers"); |
| } |
| if (minShouldMatch < 1) { |
| throw new IllegalArgumentException("minShouldMatch should be >= 1"); |
| } |
| |
| this.minShouldMatch = minShouldMatch; |
| this.doc = -1; |
| |
| head = new DisiPriorityQueue(scorers.size() - minShouldMatch + 1); |
| // there can be at most minShouldMatch - 1 scorers beyond the current position |
| // otherwise we might be skipping over matching documents |
| tail = new DisiWrapper[minShouldMatch - 1]; |
| |
| for (Scorer scorer : scorers) { |
| addLead(new DisiWrapper(scorer)); |
| } |
| |
| this.cost = cost(scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost), scorers.size(), minShouldMatch); |
| } |
| |
| @Override |
| public final Collection<ChildScorable> getChildren() throws IOException { |
| List<ChildScorable> matchingChildren = new ArrayList<>(); |
| updateFreq(); |
| for (DisiWrapper s = lead; s != null; s = s.next) { |
| matchingChildren.add(new ChildScorable(s.scorer, "SHOULD")); |
| } |
| return matchingChildren; |
| } |
| |
| @Override |
| public DocIdSetIterator iterator() { |
| return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); |
| } |
| |
| @Override |
| public TwoPhaseIterator twoPhaseIterator() { |
| DocIdSetIterator approximation = new DocIdSetIterator() { |
| |
| @Override |
| public int docID() { |
| assert doc == lead.doc; |
| return doc; |
| } |
| |
| @Override |
| public int nextDoc() throws IOException { |
| // We are moving to the next doc ID, so scorers in 'lead' need to go in |
| // 'tail'. If there is not enough space in 'tail', then we take the least |
| // costly scorers and advance them. |
| for (DisiWrapper s = lead; s != null; s = s.next) { |
| final DisiWrapper evicted = insertTailWithOverFlow(s); |
| if (evicted != null) { |
| if (evicted.doc == doc) { |
| evicted.doc = evicted.iterator.nextDoc(); |
| } else { |
| evicted.doc = evicted.iterator.advance(doc + 1); |
| } |
| head.add(evicted); |
| } |
| } |
| |
| setDocAndFreq(); |
| // It would be correct to return doNextCandidate() at this point but if you |
| // call nextDoc as opposed to advance, it probably means that you really |
| // need the next match. Returning 'doc' here would lead to a similar |
| // iteration over sub postings overall except that the decision making would |
| // happen at a higher level where more abstractions are involved and |
| // benchmarks suggested it causes a significant performance hit. |
| return doNext(); |
| } |
| |
| @Override |
| public int advance(int target) throws IOException { |
| // Same logic as in nextDoc |
| for (DisiWrapper s = lead; s != null; s = s.next) { |
| final DisiWrapper evicted = insertTailWithOverFlow(s); |
| if (evicted != null) { |
| evicted.doc = evicted.iterator.advance(target); |
| head.add(evicted); |
| } |
| } |
| |
| // But this time there might also be scorers in 'head' behind the desired |
| // target so we need to do the same thing that we did on 'lead' on 'head' |
| DisiWrapper headTop = head.top(); |
| while (headTop.doc < target) { |
| final DisiWrapper evicted = insertTailWithOverFlow(headTop); |
| // We know that the tail is full since it contains at most |
| // minShouldMatch - 1 entries and we just moved at least minShouldMatch |
| // entries to it, so evicted is not null |
| evicted.doc = evicted.iterator.advance(target); |
| headTop = head.updateTop(evicted); |
| } |
| |
| setDocAndFreq(); |
| return doNextCandidate(); |
| } |
| |
| @Override |
| public long cost() { |
| return cost; |
| } |
| }; |
| return new TwoPhaseIterator(approximation) { |
| |
| @Override |
| public boolean matches() throws IOException { |
| while (freq < minShouldMatch) { |
| assert freq > 0; |
| if (freq + tailSize >= minShouldMatch) { |
| // a match on doc is still possible, try to |
| // advance scorers from the tail |
| advanceTail(); |
| } else { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| @Override |
| public float matchCost() { |
| // maximum number of scorer that matches() might advance |
| return tail.length; |
| } |
| |
| }; |
| } |
| |
| private void addLead(DisiWrapper lead) { |
| lead.next = this.lead; |
| this.lead = lead; |
| freq += 1; |
| } |
| |
| private void pushBackLeads() throws IOException { |
| for (DisiWrapper s = lead; s != null; s = s.next) { |
| addTail(s); |
| } |
| } |
| |
| private void advanceTail(DisiWrapper top) throws IOException { |
| top.doc = top.iterator.advance(doc); |
| if (top.doc == doc) { |
| addLead(top); |
| } else { |
| head.add(top); |
| } |
| } |
| |
| private void advanceTail() throws IOException { |
| final DisiWrapper top = popTail(); |
| advanceTail(top); |
| } |
| |
| /** Reinitializes head, freq and doc from 'head' */ |
| private void setDocAndFreq() { |
| assert head.size() > 0; |
| |
| // The top of `head` defines the next potential match |
| // pop all documents which are on this doc |
| lead = head.pop(); |
| lead.next = null; |
| freq = 1; |
| doc = lead.doc; |
| while (head.size() > 0 && head.top().doc == doc) { |
| addLead(head.pop()); |
| } |
| } |
| |
| /** Advance tail to the lead until there is a match. */ |
| private int doNext() throws IOException { |
| while (freq < minShouldMatch) { |
| assert freq > 0; |
| if (freq + tailSize >= minShouldMatch) { |
| // a match on doc is still possible, try to |
| // advance scorers from the tail |
| advanceTail(); |
| } else { |
| // no match on doc is possible anymore, move to the next potential match |
| pushBackLeads(); |
| setDocAndFreq(); |
| } |
| } |
| |
| return doc; |
| } |
| |
| /** Move iterators to the tail until the cumulated size of lead+tail is |
| * greater than or equal to minShouldMath */ |
| private int doNextCandidate() throws IOException { |
| while (freq + tailSize < minShouldMatch) { |
| // no match on doc is possible, move to the next potential match |
| pushBackLeads(); |
| setDocAndFreq(); |
| } |
| |
| return doc; |
| } |
| |
| /** Advance all entries from the tail to know about all matches on the |
| * current doc. */ |
| private void updateFreq() throws IOException { |
| assert freq >= minShouldMatch; |
| // we return the next doc when there are minShouldMatch matching clauses |
| // but some of the clauses in 'tail' might match as well |
| // in general we want to advance least-costly clauses first in order to |
| // skip over non-matching documents as fast as possible. However here, |
| // we are advancing everything anyway so iterating over clauses in |
| // (roughly) cost-descending order might help avoid some permutations in |
| // the head heap |
| for (int i = tailSize - 1; i >= 0; --i) { |
| advanceTail(tail[i]); |
| } |
| tailSize = 0; |
| } |
| |
| @Override |
| public float score() throws IOException { |
| // we need to know about all matches |
| updateFreq(); |
| double score = 0; |
| for (DisiWrapper s = lead; s != null; s = s.next) { |
| score += s.scorer.score(); |
| } |
| return (float) score; |
| } |
| |
| @Override |
| public float getMaxScore(int upTo) throws IOException { |
| // TODO: implement but be careful about floating-point errors. |
| return Float.POSITIVE_INFINITY; |
| } |
| |
| @Override |
| public int docID() { |
| assert doc == lead.doc; |
| return doc; |
| } |
| |
| /** Insert an entry in 'tail' and evict the least-costly scorer if full. */ |
| private DisiWrapper insertTailWithOverFlow(DisiWrapper s) { |
| if (tailSize < tail.length) { |
| addTail(s); |
| return null; |
| } else if (tail.length >= 1) { |
| final DisiWrapper top = tail[0]; |
| if (top.cost < s.cost) { |
| tail[0] = s; |
| downHeapCost(tail, tailSize); |
| return top; |
| } |
| } |
| return s; |
| } |
| |
| /** Add an entry to 'tail'. Fails if over capacity. */ |
| private void addTail(DisiWrapper s) { |
| tail[tailSize] = s; |
| upHeapCost(tail, tailSize); |
| tailSize += 1; |
| } |
| |
| /** Pop the least-costly scorer from 'tail'. */ |
| private DisiWrapper popTail() { |
| assert tailSize > 0; |
| final DisiWrapper result = tail[0]; |
| tail[0] = tail[--tailSize]; |
| downHeapCost(tail, tailSize); |
| return result; |
| } |
| |
| /** Heap helpers */ |
| |
| private static void upHeapCost(DisiWrapper[] heap, int i) { |
| final DisiWrapper node = heap[i]; |
| final long nodeCost = node.cost; |
| int j = parentNode(i); |
| while (j >= 0 && nodeCost < heap[j].cost) { |
| heap[i] = heap[j]; |
| i = j; |
| j = parentNode(j); |
| } |
| heap[i] = node; |
| } |
| |
| private static void downHeapCost(DisiWrapper[] heap, int size) { |
| int i = 0; |
| final DisiWrapper node = heap[0]; |
| int j = leftNode(i); |
| if (j < size) { |
| int k = rightNode(j); |
| if (k < size && heap[k].cost < heap[j].cost) { |
| j = k; |
| } |
| if (heap[j].cost < node.cost) { |
| do { |
| heap[i] = heap[j]; |
| i = j; |
| j = leftNode(i); |
| k = rightNode(j); |
| if (k < size && heap[k].cost < heap[j].cost) { |
| j = k; |
| } |
| } while (j < size && heap[j].cost < node.cost); |
| heap[i] = node; |
| } |
| } |
| } |
| |
| } |