blob: ce0022d5196e1c54e3c4590e44e1b6183d796619 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
/**
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
* search algorithm, see {@link HnswGraph}.
*
* @param <T> the type of query vector
*/
public class HnswGraphSearcher<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
/**
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
* to allocate, so they're cleared and reused across calls.
*/
private final NeighborQueue candidates;
private BitSet visited;
/**
* Creates a new graph searcher.
*
* @param similarityFunction the similarity function to compare vectors
* @param candidates max heap that will track the candidate nodes to explore
* @param visited bit set that will track nodes that have already been visited
*/
public HnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction;
this.candidates = candidates;
this.visited = visited;
}
/**
* Searches HNSW graph for the nearest neighbors of a query vector.
*
* @param query search query vector
* @param topK the number of nodes to be returned
* @param vectors the vector values
* @param similarityFunction the similarity function to compare vectors
* @param graph the graph values. May represent the entire graph, or a level in a hierarchical
* graph.
* @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
* {@code null} if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return a priority queue holding the closest neighbors found
*/
public static NeighborQueue search(
float[] query,
int topK,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
if (vectorEncoding == VectorEncoding.BYTE) {
return search(
toBytesRef(query),
topK,
vectors,
vectorEncoding,
similarityFunction,
graph,
acceptOrds,
visitedLimit);
}
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
private static NeighborQueue search(
BytesRef query,
int topK,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
HnswGraphSearcher<BytesRef> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
/**
* Searches for the nearest neighbors of a query vector in a given level.
*
* <p>If the search stops early because it reaches the visited nodes limit, then the results will
* be marked incomplete through {@link NeighborQueue#incomplete()}.
*
* @param query search query vector
* @param topK the number of nearest to query results to return
* @param level level to search
* @param eps the entry points for search at this level expressed as level 0th ordinals
* @param vectors vector values
* @param graph the graph values
* @return a priority queue holding the closest neighbors found
*/
public NeighborQueue searchLevel(
// Note: this is only public because Lucene91HnswGraphBuilder needs it
T query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues vectors,
HnswGraph graph)
throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
}
private NeighborQueue searchLevel(
T query,
int topK,
int level,
final int[] eps,
RandomAccessVectorValues vectors,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, false);
prepareScratchState(vectors.size());
int numVisited = 0;
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = compare(query, vectors, ep);
numVisited++;
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.add(ep, score);
}
}
}
// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
float topCandidateSimilarity = candidates.topScore();
if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
}
int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
int friendOrd;
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
}
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float friendSimilarity = compare(query, vectors, friendOrd);
numVisited++;
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}
}
}
}
}
while (results.size() > topK) {
results.pop();
}
results.setVisitedCount(numVisited);
return results;
}
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
} else {
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
}
}
private void prepareScratchState(int capacity) {
candidates.clear();
if (visited.length() < capacity) {
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
}
visited.clear(0, visited.length());
}
}