blob: a93bed2cde67f08d82aba15ab0d31494146ea451 [file] [log] [blame]
package org.apache.lucene.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.util.PriorityQueue;
import static org.apache.lucene.search.DisiPriorityQueue.leftNode;
import static org.apache.lucene.search.DisiPriorityQueue.parentNode;
import static org.apache.lucene.search.DisiPriorityQueue.rightNode;
/**
* A {@link Scorer} for {@link BooleanQuery} when
* {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int) minShouldMatch} is
* between 2 and the total number of clauses.
*
* This implementation keeps sub scorers in 3 different places:
* - lead: a linked list of scorers that are positioned on the desired doc ID
* - tail: a heap that contains at most minShouldMatch - 1 scorers that are
* behind the desired doc ID. These scorers are ordered by cost so that we
* can advance the least costly ones first.
* - head: a heap that contains scorers which are beyond the desired doc ID,
* ordered by doc ID in order to move quickly to the next candidate.
*
* Finding the next match consists of first setting the desired doc ID to the
* least entry in 'head' and then advancing 'tail' until there is a match.
*/
final class MinShouldMatchSumScorer extends Scorer {
/**
* Computes the cost estimate of this scorer: the sum of the costs of the
* {@code scorers.size() - minShouldMatch + 1} least costly scorers (see the
* reasoning below). Only called from the constructor, after it has checked
* that {@code 1 <= minShouldMatch <= scorers.size()}.
*
* @param scorers the sub scorers
* @param minShouldMatch the minimum number of sub scorers that must match
* @return the summed cost of the retained (least costly) scorers
*/
private static long cost(Collection<Scorer> scorers, int minShouldMatch) {
// the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
// could be rewritten to:
// (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
// if we assume that clauses come in ascending cost, then
// the cost of the first part is the cost of c1 (because the cost of a conjunction is
// the cost of the least costly clause)
// the cost of the second part is the cost of finding m matches among the c2...cn
// remaining clauses
// since it is a disjunction overall, the total cost is the sum of the costs of these
// two parts
// If we recurse infinitely, we find out that the cost of a msm query is the sum of the
// costs of the num_scorers - minShouldMatch + 1 least costly scorers
final PriorityQueue<Scorer> pq = new PriorityQueue<Scorer>(scorers.size() - minShouldMatch + 1) {
@Override
protected boolean lessThan(Scorer a, Scorer b) {
// Deliberately inverted: the MOST costly scorer is treated as the
// "least" entry. Assuming Lucene's PriorityQueue.insertWithOverflow
// discards the least entry once the queue is full, only the cheapest
// scorers remain in 'pq' after the insertion loop below.
return a.cost() > b.cost();
}
};
for (Scorer scorer : scorers) {
pq.insertWithOverflow(scorer);
}
// Drain the retained scorers and add up their costs.
long cost = 0;
for (Scorer scorer = pq.pop(); scorer != null; scorer = pq.pop()) {
cost += scorer.cost();
}
return cost;
}
// minimum number of sub scorers that must be positioned on a doc for it to
// be returned as a match (doNext loops while freq < minShouldMatch)
final int minShouldMatch;
// score multipliers indexed by the number of matching sub scorers
// (read as coord[freq] in score())
final float[] coord;
// list of scorers which 'lead' the iteration and are currently
// positioned on 'doc'
// (singly linked through DisiWrapper.next; built by addLead/setDocAndFreq)
DisiWrapper<Scorer> lead;
int doc; // current doc ID of the leads
int freq; // number of scorers on the desired doc ID (incremented by addLead, reset by setDocAndFreq)
// priority queue of scorers that are too advanced compared to the current
// doc. Ordered by doc ID.
final DisiPriorityQueue<Scorer> head;
// priority queue of scorers which are behind the current doc.
// Ordered by cost.
// Stored as an array-backed min-heap on DisiWrapper.cost (least costly at
// index 0) holding at most minShouldMatch - 1 entries; only the first
// 'tailSize' slots are valid. Maintained by addTail/popTail and the
// upHeapCost/downHeapCost helpers.
final DisiWrapper<Scorer>[] tail;
// number of valid entries at the start of 'tail'
int tailSize;
// unmodifiable view returned by getChildren()
final Collection<ChildScorer> childScorers;
// precomputed once by cost(Collection, int) in the constructor
final long cost;
/**
* Creates a scorer that matches documents on which at least
* {@code minShouldMatch} of the given scorers are positioned.
*
* @param weight the weight passed on to the {@link Scorer} constructor
* @param scorers the sub scorers; each one is wrapped in a {@link DisiWrapper}
* @param minShouldMatch minimum number of sub scorers that must match
* @param coord score multipliers indexed by the number of matching sub scorers
* @throws IllegalArgumentException if {@code minShouldMatch} is greater than
*         {@code scorers.size()} or less than 1
*/
@SuppressWarnings({"unchecked","rawtypes"})
MinShouldMatchSumScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch, float[] coord) {
super(weight);
if (minShouldMatch > scorers.size()) {
throw new IllegalArgumentException("minShouldMatch should be <= the number of scorers");
}
if (minShouldMatch < 1) {
throw new IllegalArgumentException("minShouldMatch should be >= 1");
}
this.minShouldMatch = minShouldMatch;
this.coord = coord;
this.doc = -1;
head = new DisiPriorityQueue<Scorer>(scorers.size() - minShouldMatch + 1);
// there can be at most minShouldMatch - 1 scorers behind the current position
// otherwise we might be skipping over matching documents
// (generic array creation: this is why the constructor suppresses
// "unchecked"/"rawtypes" warnings)
tail = new DisiWrapper[minShouldMatch - 1];
// Every scorer starts out as a lead, so freq == scorers.size() after this
// loop. NOTE(review): docID() asserts doc == lead.doc, so this assumes a
// fresh DisiWrapper starts with doc == -1 like this.doc; confirm in
// DisiWrapper.
for (Scorer scorer : scorers) {
addLead(new DisiWrapper<Scorer>(scorer));
}
List<ChildScorer> children = new ArrayList<>();
for (Scorer scorer : scorers) {
children.add(new ChildScorer(scorer, "SHOULD"));
}
this.childScorers = Collections.unmodifiableCollection(children);
this.cost = cost(scorers, minShouldMatch);
}
/** Returns the cost estimate precomputed in the constructor. */
@Override
public long cost() {
return cost;
}
/** Returns one "SHOULD" {@link ChildScorer} per sub scorer (unmodifiable). */
@Override
public final Collection<ChildScorer> getChildren() {
return childScorers;
}
@Override
public int nextDoc() throws IOException {
// We are moving to the next doc ID, so scorers in 'lead' need to go in
// 'tail'. If there is not enough space in 'tail', then we take the least
// costly scorers and advance them.
for (DisiWrapper<Scorer> s = lead; s != null; s = s.next) {
final DisiWrapper<Scorer> evicted = insertTailWithOverFlow(s);
if (evicted != null) {
// An evicted scorer is either on 'doc' (it was a lead) or behind it
// (it was already in 'tail'). Either way it is moved past 'doc'.
if (evicted.doc == doc) {
evicted.doc = evicted.iterator.nextDoc();
} else {
evicted.doc = evicted.iterator.advance(doc + 1);
}
head.add(evicted);
}
}
setDocAndFreq();
return doNext();
}
@Override
public int advance(int target) throws IOException {
// Same logic as in nextDoc
for (DisiWrapper<Scorer> s = lead; s != null; s = s.next) {
final DisiWrapper<Scorer> evicted = insertTailWithOverFlow(s);
if (evicted != null) {
evicted.doc = evicted.iterator.advance(target);
head.add(evicted);
}
}
// But this time there might also be scorers in 'head' behind the desired
// target so we need to do the same thing that we did on 'lead' on 'head'
// 'head' cannot be empty here: all leads were just moved out, 'tail' holds
// at most minShouldMatch - 1 scorers, and the constructor guarantees at
// least minShouldMatch scorers in total.
DisiWrapper<Scorer> headTop = head.top();
while (headTop.doc < target) {
final DisiWrapper<Scorer> evicted = insertTailWithOverFlow(headTop);
// We know that the tail is full since it contains at most
// minShouldMatch - 1 entries and we just moved at least minShouldMatch
// entries to it, so evicted is not null
evicted.doc = evicted.iterator.advance(target);
// The evicted entry (headTop itself or a former tail entry) replaces
// the top of 'head', which is then re-sorted by doc ID.
headTop = head.updateTop(evicted);
}
setDocAndFreq();
return doNext();
}
/** Prepends {@code lead} to the 'lead' list and increments freq. */
private void addLead(DisiWrapper<Scorer> lead) {
lead.next = this.lead;
this.lead = lead;
freq += 1;
}
/**
* Moves every lead into 'tail'. addTail does not check capacity; doNext only
* calls this when freq + tailSize < minShouldMatch, i.e. when all leads fit
* into the minShouldMatch - 1 slots. The 'lead' field itself is left
* unchanged; doNext resets it right after via setDocAndFreq.
*/
private void pushBackLeads() throws IOException {
for (DisiWrapper<Scorer> s = lead; s != null; s = s.next) {
addTail(s);
}
}
/**
* Calls advance(doc) on {@code top}, a scorer that was behind 'doc'. If it
* lands on 'doc' it becomes a lead, otherwise it is added to 'head'.
*/
private void advanceTail(DisiWrapper<Scorer> top) throws IOException {
top.doc = top.iterator.advance(doc);
if (top.doc == doc) {
addLead(top);
} else {
head.add(top);
}
}
/** Pops the least-costly scorer from 'tail' and advances it to 'doc'. */
private void advanceTail() throws IOException {
final DisiWrapper<Scorer> top = popTail();
advanceTail(top);
}
/**
* Reinitializes lead, freq and doc from 'head': pops the top of 'head' and
* every other entry positioned on the same doc ID. Requires a non-empty
* 'head'.
*/
private void setDocAndFreq() {
assert head.size() > 0;
// The top of `head` defines the next potential match
// pop all documents which are on this doc
lead = head.pop();
lead.next = null;
freq = 1;
doc = lead.doc;
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
}
}
/**
* Advance tail to the lead until there is a match.
*
* Returns 'doc' once at least minShouldMatch scorers are on it. While a
* match on 'doc' is still reachable (freq + tailSize >= minShouldMatch), the
* least costly tail scorer is advanced. Otherwise the leads are pushed back
* into 'tail' and the next candidate is taken from 'head'.
*/
private int doNext() throws IOException {
while (freq < minShouldMatch) {
assert freq > 0;
if (freq + tailSize >= minShouldMatch) {
// a match on doc is still possible, try to
// advance scorers from the tail
advanceTail();
} else {
// no match on doc is possible anymore, move to the next potential match
pushBackLeads();
setDocAndFreq();
}
}
return doc;
}
/** Advance all entries from the tail to know about all matches on the
* current doc. Afterwards 'tail' is empty: each former tail entry is
* either a lead (if it landed on 'doc') or in 'head'. */
private void updateFreq() throws IOException {
assert freq >= minShouldMatch;
// we return the next doc when there are minShouldMatch matching clauses
// but some of the clauses in 'tail' might match as well
// in general we want to advance least-costly clauses first in order to
// skip over non-matching documents as fast as possible. However here,
// we are advancing everything anyway so iterating over clauses in
// (roughly) cost-descending order might help avoid some permutations in
// the head heap
for (int i = tailSize - 1; i >= 0; --i) {
advanceTail(tail[i]);
}
tailSize = 0;
}
/** Returns the number of sub scorers on the current doc, after advancing the whole tail. */
@Override
public int freq() throws IOException {
// we need to know about all matches
updateFreq();
return freq;
}
/**
* Returns the sum of the leads' scores (accumulated in double precision)
* multiplied by coord[freq], after advancing the whole tail.
*/
@Override
public float score() throws IOException {
// we need to know about all matches
updateFreq();
double score = 0;
for (DisiWrapper<Scorer> s = lead; s != null; s = s.next) {
score += s.iterator.score();
}
return coord[freq] * (float) score;
}
/** Returns the current doc ID, which all leads are positioned on. */
@Override
public int docID() {
assert doc == lead.doc;
return doc;
}
/**
* Insert an entry in 'tail' and evict the least-costly scorer if full.
*
* Returns null if {@code s} was added and nothing was evicted. Otherwise it
* returns the scorer that does not fit: the previous least-costly entry if
* {@code s} costs more than it, or {@code s} itself. When 'tail' has no
* capacity (minShouldMatch == 1), {@code s} is always returned.
*/
private DisiWrapper<Scorer> insertTailWithOverFlow(DisiWrapper<Scorer> s) {
if (tailSize < tail.length) {
addTail(s);
return null;
} else if (tail.length >= 1) {
final DisiWrapper<Scorer> top = tail[0];
if (top.cost < s.cost) {
// Replace the cheapest entry with s and restore heap order.
tail[0] = s;
downHeapCost(tail, tailSize);
return top;
}
}
return s;
}
/** Add an entry to 'tail'. Fails if over capacity (tail[tailSize] is out of bounds). */
private void addTail(DisiWrapper<Scorer> s) {
tail[tailSize] = s;
upHeapCost(tail, tailSize);
tailSize += 1;
}
/** Pop the least-costly scorer from 'tail'. */
private DisiWrapper<Scorer> popTail() {
assert tailSize > 0;
final DisiWrapper<Scorer> result = tail[0];
// Move the last entry to the root and sift it down over the shrunk heap.
tail[0] = tail[--tailSize];
downHeapCost(tail, tailSize);
return result;
}
/** Heap helpers */
// Sifts heap[i] up towards the root while it costs less than its parent
// (min-heap on DisiWrapper.cost). Index helpers come from DisiPriorityQueue.
private static void upHeapCost(DisiWrapper<Scorer>[] heap, int i) {
final DisiWrapper<Scorer> node = heap[i];
final long nodeCost = node.cost;
int j = parentNode(i);
while (j >= 0 && nodeCost < heap[j].cost) {
heap[i] = heap[j];
i = j;
j = parentNode(j);
}
heap[i] = node;
}
// Sifts heap[0] down among the first 'size' entries, swapping with the
// cheaper child while that child costs less. Does nothing when the root
// has no child within 'size' or is already no more costly than its
// cheaper child. Note that rightNode is given the left child's index.
private static void downHeapCost(DisiWrapper<Scorer>[] heap, int size) {
int i = 0;
final DisiWrapper<Scorer> node = heap[0];
int j = leftNode(i);
if (j < size) {
int k = rightNode(j);
if (k < size && heap[k].cost < heap[j].cost) {
j = k;
}
if (heap[j].cost < node.cost) {
do {
heap[i] = heap[j];
i = j;
j = leftNode(i);
k = rightNode(j);
if (k < size && heap[k].cost < heap[j].cost) {
j = k;
}
} while (j < size && heap[j].cost < node.cost);
heap[i] = node;
}
}
}
}