blob: c9b8f5258eea8526b252eacd6b1f4f12f031b89e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
*
* <ul>
* <li>If the filter cost is less than k, just execute an exact search
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
public class KnnVectorQuery extends Query {

/** Shared sentinel for leaves that produce no hits, to avoid allocating empty results per leaf. */
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

// Name of the indexed KnnVectorField to search.
private final String field;
// The query vector; hits are ranked by similarity to it.
private final float[] target;
// Number of nearest documents to return.
private final int k;
// Optional pre-filter applied before the vector search; null means unfiltered.
private final Query filter;
/**
 * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
 * given field.
 *
 * @param field a field that has been indexed as a {@link KnnVectorField}.
 * @param target the target of the search
 * @param k the number of documents to find
 * @throws IllegalArgumentException if <code>k</code> is less than 1
 */
public KnnVectorQuery(String field, float[] target, int k) {
  this(field, target, k, null);
}
/**
 * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
 * given field.
 *
 * @param field a field that has been indexed as a {@link KnnVectorField}.
 * @param target the target of the search
 * @param k the number of documents to find
 * @param filter a filter applied before the vector search, or null for no filtering
 * @throws IllegalArgumentException if <code>k</code> is less than 1
 * @throws NullPointerException if <code>field</code> or <code>target</code> is null
 */
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
  // Validate before assigning any state, so a failed construction never leaves a
  // half-initialized query, and fail fast on nulls instead of NPE-ing in rewrite/toString.
  if (k < 1) {
    throw new IllegalArgumentException("k must be at least 1, got: " + k);
  }
  this.field = Objects.requireNonNull(field, "field");
  this.target = Objects.requireNonNull(target, "target");
  this.k = k;
  this.filter = filter; // null is a valid value meaning "no filter"
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
  int numLeaves = reader.leaves().size();
  // When a pre-filter is present, run it once over the whole index, collecting the
  // matching doc ids per leaf (restricted to docs that actually have a vector).
  BitSetCollector matchCollector = null;
  if (filter != null) {
    matchCollector = new BitSetCollector(numLeaves);
    Query filteredWithVector =
        new BooleanQuery.Builder()
            .add(filter, BooleanClause.Occur.FILTER)
            .add(new KnnVectorFieldExistsQuery(field), BooleanClause.Occur.FILTER)
            .build();
    new IndexSearcher(reader).search(filteredWithVector, matchCollector);
  }
  // Search each leaf independently, rebasing its hits into the global doc-id space.
  TopDocs[] leafResults = new TopDocs[numLeaves];
  for (LeafReaderContext leaf : reader.leaves()) {
    TopDocs hits = searchLeaf(leaf, matchCollector);
    if (leaf.docBase > 0) {
      for (ScoreDoc hit : hits.scoreDocs) {
        hit.doc += leaf.docBase;
      }
    }
    leafResults[leaf.ord] = hits;
  }
  // Merge sort the per-leaf results into the global top k.
  TopDocs merged = TopDocs.merge(k, leafResults);
  return merged.scoreDocs.length == 0
      ? new MatchNoDocsQuery()
      : createRewrittenQuery(reader, merged);
}
/**
 * Runs the kNN search for one leaf, dynamically choosing between approximate and exact search.
 *
 * @param ctx the leaf to search
 * @param filterCollector the per-leaf filter matches, or null when the query has no filter
 * @return the top hits for this leaf, with leaf-local doc ids
 */
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
throws IOException {
if (filterCollector == null) {
// Unfiltered: only skip deleted docs, and impose no visitation budget.
Bits acceptDocs = ctx.reader().getLiveDocs();
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
} else {
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
// Either the filter never visited this leaf or it matched nothing here.
if (filterIterator == null || filterIterator.cost() == 0) {
return NO_RESULTS;
}
if (filterIterator.cost() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
}
// Perform the approximate kNN search
Bits acceptDocs =
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
// Budget the graph search to the filter's match count: past that, exact search is cheaper.
int visitedLimit = (int) filterIterator.cost();
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
// The search completed within the budget, so the result is trustworthy.
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
}
}
}
/**
 * Delegates to the leaf reader's native kNN search, mapping a null result to {@link #NO_RESULTS}.
 *
 * @param context the leaf to search
 * @param acceptDocs docs eligible to match, or null to accept all
 * @param visitedLimit maximum number of vectors the search may visit before giving up
 */
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
    throws IOException {
  TopDocs hits =
      context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
  if (hits == null) {
    return NO_RESULTS;
  }
  return hits;
}
// We allow this to be overridden so that tests can check what search strategy is used
/**
 * Brute-force search: scores every candidate produced by {@code acceptIterator} against the
 * target vector and keeps the best {@code k}.
 *
 * @param context the leaf being searched
 * @param acceptIterator the candidate docs; its {@link DocIdSetIterator#cost()} must be the
 *     exact candidate count, and each candidate is expected to have a vector value
 * @return the top hits with leaf-local doc ids and an exact total-hit count
 */
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
VectorValues vectorValues = context.reader().getVectorValues(field);
// Queue is pre-filled with sentinel entries (negative scores — see the drain loop below),
// so updateTop() can be used unconditionally instead of insertWithOverflow.
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Every accepted doc should carry a vector, so advance should land exactly on it.
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
float score = similarityFunction.convertToScore(similarityFunction.compare(vector, target));
if (score >= topDoc.score) {
// Beats the current k-th best hit: overwrite it and re-heapify.
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
// Drain the queue (worst hit first) into descending-score order.
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
/**
 * Collector that records, for each leaf it visits, a {@link BitSet} of the matching doc ids
 * together with the exact number of matches.
 */
private static class BitSetCollector extends SimpleCollector {
  private final BitSet[] matchSets; // one entry per leaf ordinal; null if the leaf was not visited
  private final int[] counts; // exact cardinality of the corresponding bit set
  private int currentLeaf; // ordinal of the leaf currently being collected

  private BitSetCollector(int numLeaves) {
    matchSets = new BitSet[numLeaves];
    counts = new int[numLeaves];
  }

  /**
   * Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
   * BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
   * null.
   */
  public BitSetIterator getIterator(int contextOrd) {
    BitSet matches = matchSets[contextOrd];
    return matches == null ? null : new BitSetIterator(matches, counts[contextOrd]);
  }

  @Override
  public void collect(int doc) throws IOException {
    matchSets[currentLeaf].set(doc);
    counts[currentLeaf]++;
  }

  @Override
  protected void doSetNextReader(LeafReaderContext context) throws IOException {
    currentLeaf = context.ord;
    matchSets[currentLeaf] = new FixedBitSet(context.reader().maxDoc());
  }

  @Override
  public org.apache.lucene.search.ScoreMode scoreMode() {
    // Only membership matters; scoring is not needed to build the filter bit sets.
    return org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
  }
}
/**
 * Packages the merged global top-k hits into a {@link DocAndScoreQuery}. Sorts the hits by
 * global doc id (mutating {@code topK.scoreDocs} in place) so per-segment ranges can later be
 * located by binary search.
 */
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
  ScoreDoc[] hits = topK.scoreDocs;
  Arrays.sort(hits, Comparator.comparingInt(a -> a.doc));
  int[] docs = new int[hits.length];
  float[] scores = new float[hits.length];
  int i = 0;
  for (ScoreDoc hit : hits) {
    docs[i] = hit.doc;
    scores[i] = hit.score;
    i++;
  }
  int[] segmentStarts = findSegmentStarts(reader, docs);
  return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
}
/**
 * Computes, for each leaf, the index into {@code docs} of its first hit: leaf i's hits occupy
 * the range {@code [starts[i], starts[i + 1])}, and the final entry is always
 * {@code docs.length}. Requires {@code docs} to be sorted ascending.
 */
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
  int numLeaves = reader.leaves().size();
  int[] starts = new int[numLeaves + 1];
  starts[numLeaves] = docs.length;
  if (numLeaves == 1) {
    // Single segment: all hits start at index 0, nothing to locate.
    return starts;
  }
  int from = 0;
  for (int leaf = 1; leaf < numLeaves; leaf++) {
    // The first hit of this leaf is the first doc id >= the leaf's docBase.
    int base = reader.leaves().get(leaf).docBase;
    int pos = Arrays.binarySearch(docs, from, docs.length, base);
    if (pos < 0) {
      pos = -1 - pos; // binarySearch insertion-point convention
    }
    starts[leaf] = pos;
    from = pos;
  }
  return starts;
}
@Override
public String toString(String field) {
  // Only the first component of the (possibly long) target vector is shown.
  StringBuilder sb = new StringBuilder(getClass().getSimpleName());
  sb.append(':').append(this.field);
  sb.append('[').append(target[0]).append(",...]");
  sb.append('[').append(k).append(']');
  return sb.toString();
}
@Override
public void visit(QueryVisitor visitor) {
  // Leaf query: report this node only if the visitor cares about our field.
  if (visitor.acceptField(field) == false) {
    return;
  }
  visitor.visitLeaf(this);
}
@Override
public boolean equals(Object obj) {
  if (sameClassAs(obj) == false) {
    return false;
  }
  KnnVectorQuery other = (KnnVectorQuery) obj;
  // The filter must participate in equality: two queries over the same field/target/k but with
  // different pre-filters produce different results, and treating them as equal would let query
  // caches and rewrite deduplication conflate them. Objects.equals handles the null (unfiltered)
  // case.
  return k == other.k
      && field.equals(other.field)
      && Arrays.equals(target, other.target)
      && Objects.equals(filter, other.filter);
}

@Override
public int hashCode() {
  // Keep in sync with equals: hash over field, k, target contents, and the (nullable) filter.
  return Objects.hash(classHash(), field, k, Arrays.hashCode(target), filter);
}
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
// Number of documents originally requested; used in toString and explain messages.
private final int k;
// Matching global doc ids, sorted ascending.
private final int[] docs;
// scores[i] is the score of docs[i].
private final float[] scores;
// segmentStarts[i] is the index in docs/scores of segment i's first hit; see constructor doc.
private final int[] segmentStarts;
// Identity of the reader context this query was built from; createWeight rejects any
// other reader, since the cached doc ids are meaningless elsewhere.
private final Object contextIdentity;
/**
 * Constructor
 *
 * @param k the number of documents requested
 * @param docs the global docids of documents that match, in ascending order
 * @param scores the scores of the matching documents
 * @param segmentStarts the indexes in docs and scores corresponding to the first matching
 *     document in each segment. If a segment has no matching documents, it should be assigned
 *     the index of the next segment that does. There should be a final entry that is always
 *     docs.length.
 * @param contextIdentity an object identifying the reader context that was used to build this
 *     query
 */
DocAndScoreQuery(
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
// The cached docs/scores are only valid against the exact reader they were computed from.
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) {
// NOTE(review): 'doc' is looked up directly in the global 'docs' array without adding
// context.docBase — this assumes callers pass a rebased (global) doc id; confirm against
// the Weight#explain contract.
int found = Arrays.binarySearch(docs, doc);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
return Explanation.match(scores[found], "within top " + k);
}
@Override
public Scorer scorer(LeafReaderContext context) {
return new Scorer(this) {
// This segment's hits occupy the half-open range [lower, upper) of docs/scores.
final int lower = segmentStarts[context.ord];
final int upper = segmentStarts[context.ord + 1];
// Cursor into docs/scores; -1 means iteration has not started yet.
int upTo = -1;
@Override
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
@Override
public int docID() {
return docIdNoShadow();
}
@Override
public int nextDoc() {
if (upTo == -1) {
upTo = lower;
} else {
++upTo;
}
return docIdNoShadow();
}
@Override
public int advance(int target) throws IOException {
// At most k hits per segment, so a linear slowAdvance is cheap enough.
return slowAdvance(target);
}
@Override
public long cost() {
return upper - lower;
}
};
}
@Override
public float getMaxScore(int docid) {
// Scan the not-yet-passed hits up to 'docid' (rebased to a global id) for the best score.
docid += context.docBase;
float maxScore = 0;
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore;
}
@Override
public float score() {
return scores[upTo];
}
@Override
public int advanceShallow(int docid) {
// Binary-search forward from the current position for the first hit >= docid (global).
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
return NO_MORE_DOCS;
}
return docs[docidIndex];
}
/**
 * move the implementation of docID() into a differently-named method so we can call it
 * from DocIDSetIterator.docID() even though this class is anonymous
 *
 * @return the current docid
 */
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
}
if (upTo >= upper) {
return NO_MORE_DOCS;
}
// docs stores global ids; convert back to a leaf-local id for this segment.
return docs[upTo] - context.docBase;
}
@Override
public int docID() {
return docIdNoShadow();
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
// Results are a precomputed constant per leaf, so caching is always safe.
return true;
}
};
}
@Override
public String toString(String field) {
  // Cached-result query: only k is meaningful to display.
  return new StringBuilder("DocAndScore[").append(k).append(']').toString();
}
@Override
public void visit(QueryVisitor visitor) {
// Leaf query with no nested sub-queries, so just report this node.
visitor.visitLeaf(this);
}
@Override
public boolean equals(Object obj) {
  if (sameClassAs(obj)) {
    DocAndScoreQuery that = (DocAndScoreQuery) obj;
    // The reader context is compared by identity on purpose: results are only meaningful
    // for the exact reader they were computed against.
    return contextIdentity == that.contextIdentity
        && Arrays.equals(docs, that.docs)
        && Arrays.equals(scores, that.scores);
  }
  return false;
}
@Override
public int hashCode() {
  // Hash the array contents (not their identities), consistent with equals.
  int docsHash = Arrays.hashCode(docs);
  int scoresHash = Arrays.hashCode(scores);
  return Objects.hash(classHash(), contextIdentity, docsHash, scoresHash);
}
}
}