* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
import static;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
* <ul>
* <li>If the filter cost is less than k, just execute an exact search
* <li>Otherwise run a kNN search subject to the filter
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
public class KnnVectorQuery extends Query {
// Shared empty result, returned whenever a leaf cannot contribute any hits.
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
// Name of the vector field to search.
private final String field;
// The query vector whose nearest neighbours are sought.
private final float[] target;
// Number of nearest documents to find; the constructor rejects values below 1.
private final int k;
// Optional filter restricting the candidate documents; null means no filter is applied.
private final Query filter;
* Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
* vectors in the given field.
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1
public KnnVectorQuery(String field, float[] target, int k) {
// Delegate to the four-argument constructor, passing null to mean "no filter".
this(field, target, k, null);
* Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
* vectors in the given field, considering only documents that match the given filter.
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
// NOTE(review): the second statement on the next line has no left-hand side; presumably it was
// meant to read `this.target = target;` - confirm against the upstream file.
this.field = field; = target;
this.k = k;
// Reject k below 1 with a message that includes the offending value.
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
// A null filter is allowed; the rest of the class treats it as "no filter".
this.filter = filter;
// Runs the nearest-neighbour search on every leaf of the reader, merges the per-leaf hits into a
// single top-k, and returns a query built from those hits (or a MatchNoDocsQuery when the merged
// result is empty).
public Query rewrite(IndexReader reader) throws IOException {
// One result slot per leaf, addressed by leaf ordinal.
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
// Stays null when there is no filter; searchLeaf uses that to pick the unfiltered path.
BitSetCollector filterCollector = null;
if (filter != null) {
filterCollector = new BitSetCollector(reader.leaves().size());
IndexSearcher indexSearcher = new IndexSearcher(reader);
// Combine the user filter with a field-exists query on the vector field; both are added as
// FILTER clauses.
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new KnnVectorFieldExistsQuery(field), BooleanClause.Occur.FILTER)
// NOTE(review): the text after `.build();` looks like the tail of a separate call that takes
// booleanQuery and filterCollector (presumably indexSearcher.search(...)) - confirm.
.build();, filterCollector);
for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterCollector);
// Shift the doc ids returned for this leaf by the leaf's docBase; skipped when docBase is 0
// because the ids would not change.
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
perLeafResults[ctx.ord] = results;
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
return createRewrittenQuery(reader, topK);
private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector)
throws IOException {
if (filterCollector == null) {
Bits acceptDocs = ctx.reader().getLiveDocs();
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
} else {
BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
if (filterIterator == null || filterIterator.cost() == 0) {
return NO_RESULTS;
if (filterIterator.cost() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
// Perform the approximate kNN search
Bits acceptDocs =
filterIterator.getBitSet(); // The filter iterator already incorporates live docs
int visitedLimit = (int) filterIterator.cost();
TopDocs results = approximateSearch(ctx, acceptDocs, visitedLimit);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
return results != null ? results : NO_RESULTS;
// We allow this to be overridden so that tests can check what search strategy is used
// Scores each document produced by acceptIterator against the target vector, keeps the best k in
// a HitQueue, and returns them as TopDocs whose total hit count is the iterator's cost.
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
VectorValues vectorValues = context.reader().getVectorValues(field);
// NOTE(review): presumably the `true` flag pre-fills the queue with sentinel entries (see the
// "Remove any remaining sentinel values" step below) - confirm against HitQueue.
HitQueue queue = new HitQueue(k, true);
// NOTE(review): the initializer is missing here; presumably the queue's current top entry.
ScoreDoc topDoc =;
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Position the vector values on the accepted doc; the assert requires it to have a vector.
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
// NOTE(review): the first argument is missing; presumably the similarity of `vector` and
// `target` is computed and then converted to a score - confirm.
float score = similarityFunction.convertToScore(, target));
// A score at least as good as the current top entry overwrites that entry in place, then the
// queue is asked for its new top.
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
// Remove any remaining sentinel values
// NOTE(review): the left operand of `< 0` and the loop body are missing on the next line.
while (queue.size() > 0 && < 0) {
// Fill the array back-to-front from successive pops.
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
// The iterator's cost is reported as an exact hit count.
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
// Collector that keeps one bit set and one count per leaf, both indexed by leaf ordinal.
private static class BitSetCollector extends SimpleCollector {
// Per-leaf bit sets; an entry stays null until doSetNextReader is called for that leaf.
private final BitSet[] bitSets;
// Per-leaf counts, handed to BitSetIterator as its cost in getIterator.
private final int[] cost;
// Ordinal of the leaf most recently passed to doSetNextReader.
private int ord;
private BitSetCollector(int numLeaves) {
this.bitSets = new BitSet[numLeaves];
this.cost = new int[bitSets.length];
* Return an iterator whose {@link BitSet} contains the matching documents, and whose {@link
* BitSetIterator#cost()} is the exact cardinality. If the leaf was never visited, then return
* null.
public BitSetIterator getIterator(int contextOrd) {
if (bitSets[contextOrd] == null) {
return null;
return new BitSetIterator(bitSets[contextOrd], cost[contextOrd]);
// NOTE(review): the body of collect is not visible here; presumably it sets bit `doc` in
// bitSets[ord] and increments cost[ord] - confirm against the upstream file.
public void collect(int doc) throws IOException {
protected void doSetNextReader(LeafReaderContext context) throws IOException {
// Allocate a fresh bit set sized to the leaf's maxDoc and remember which leaf is current.
bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
ord = context.ord;
// NOTE(review): the return type and body of scoreMode are missing on the next line.
public scoreMode() {
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
for (int i = 0; i < len; i++) {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
int[] starts = new int[reader.leaves().size() + 1];
starts[starts.length - 1] = docs.length;
if (starts.length == 2) {
return starts;
int resultIndex = 0;
for (int i = 1; i < starts.length - 1; i++) {
int upper = reader.leaves().get(i).docBase;
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
if (resultIndex < 0) {
resultIndex = -1 - resultIndex;
starts[i] = resultIndex;
return starts;
public String toString(String field) {
return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]";
// Consults the visitor about this query's field before doing anything else.
// NOTE(review): the body of the if is not visible here; presumably it reports this query to the
// visitor as a leaf - confirm against the upstream file.
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
public boolean equals(Object obj) {
return sameClassAs(obj)
&& ((KnnVectorQuery) obj).k == k
&& ((KnnVectorQuery) obj).field.equals(field)
&& Arrays.equals(((KnnVectorQuery) obj).target, target);
public int hashCode() {
return Objects.hash(classHash(), field, k, Arrays.hashCode(target));
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
// Number of documents originally requested; appears in explanations and in toString.
private final int k;
// Global doc ids of the matching documents, in ascending order.
private final int[] docs;
// Scores of the matching documents, parallel to docs.
private final float[] scores;
// Offsets into docs/scores: entries ord and ord + 1 bound the slice for the leaf with that ordinal.
private final int[] segmentStarts;
// Identity of the reader context this query was built for; compared with != in createWeight.
private final Object contextIdentity;
* Constructor
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
* document in each segment. If a segment has no matching documents, it should be assigned
* the index of the next segment that does. There should be a final entry that is always
* docs.length.
* @param contextIdentity an object identifying the reader context that was used to build this
* query
// NOTE(review): the constructor name and opening parenthesis are missing from the next line, and
// the second statement on the line after it has no left-hand side; presumably it was meant to
// read `this.docs = docs;` - confirm against the upstream file.
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k; = docs;
// The arrays are stored as passed in, without copying.
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
// Builds a Weight over the cached docs and scores. The scoreMode and boost arguments are not
// referenced in the visible body.
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
// The cached doc ids only make sense for the reader context they were computed against, so a
// searcher whose context has a different identity is rejected.
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
return new Weight(this) {
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
return Explanation.match(scores[found], "within top " + k);
// Returns a Scorer that walks the slice of docs/scores belonging to the given leaf.
public Scorer scorer(LeafReaderContext context) {
return new Scorer(this) {
// [lower, upper) is this leaf's slice of the docs/scores arrays.
final int lower = segmentStarts[context.ord];
final int upper = segmentStarts[context.ord + 1];
// Current index into docs/scores; -1 means iteration has not started.
int upTo = -1;
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
public int docID() {
return docIdNoShadow();
public int nextDoc() {
// The first call jumps to the start of this leaf's slice.
if (upTo == -1) {
upTo = lower;
} else {
// NOTE(review): the else branch as shown returns without moving upTo; a statement advancing
// upTo appears to be missing before the return - confirm against the upstream file.
return docIdNoShadow();
public int advance(int target) throws IOException {
return slowAdvance(target);
public long cost() {
// Number of cached hits in this leaf.
return upper - lower;
public float getMaxScore(int docid) {
// Convert the leaf-relative bound to a global doc id before comparing with docs.
docid += context.docBase;
float maxScore = 0;
// NOTE(review): before iteration starts the scan begins at index 0 rather than at `lower`, so
// entries of earlier leaves may be included; that can only raise the returned maximum -
// confirm whether `lower` was intended.
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
return maxScore;
public float score() {
return scores[upTo];
public int advanceShallow(int docid) {
// Search only the part of this leaf's slice that has not been passed yet.
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
// A negative result encodes the insertion point as -(insertionPoint) - 1.
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
if (docidIndex >= upper) {
return NO_MORE_DOCS;
// NOTE(review): this returns the stored global id, whereas docIdNoShadow subtracts
// context.docBase - confirm which form callers expect.
return docs[docidIndex];
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous
* @return the current docid
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
if (upTo >= upper) {
return NO_MORE_DOCS;
// Convert the stored global id back to a leaf-relative id.
return docs[upTo] - context.docBase;
public int docID() {
return docIdNoShadow();
// Reports every leaf as cacheable, regardless of the context passed in.
public boolean isCacheable(LeafReaderContext ctx) {
return true;
public String toString(String field) {
return "DocAndScore[" + k + "]";
// NOTE(review): the body of this method is not visible here; presumably it reports this query to
// the visitor as a leaf - confirm against the upstream file.
public void visit(QueryVisitor visitor) {
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
public int hashCode() {
return Objects.hash(
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));