Use vector bulk scoring in entry-point and filter hnsw search (#15500)
We currently use the bulk scoring interface on the lowest HNSW level, but there isn't any technical reason why we cannot use bulk scoring on higher levels to find the entry point. For highly connected graphs with many levels, this could provide a small speed improvement.
Additionally, I noticed that we were not doing bulk scoring when using the filtered HNSW searcher (FilteredHnswGraphSearcher). We should use bulk scoring there as well, both for entry-point scoring and for scoring the collected candidates.
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 1a2ec6a..660decf 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -260,6 +260,9 @@
* GITHUB#15474: Use bulk scoring provided by RandomVectorScorers for new scalar quantized formats provided through
Lucene104ScalarQuantizedVectorsFormat and Lucene104HnswScalarQuantizedVectorsFormat (Ben Trent)
+* GITHUB#15500: Use bulk scoring for filtered HNSW search and for entry-point scoring in the graph. This should
+ provide speed improvements when using vector scorers that implement the bulk scoring interface. (Ben Trent)
+
Bug Fixes
---------------------
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java
index ab602e4..2f668d9 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/AbstractHnswGraphSearcher.java
@@ -18,6 +18,7 @@
import java.io.IOException;
import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
/**
@@ -81,4 +82,28 @@
}
searchLevel(results, scorer, 0, eps, graph, acceptOrds);
}
+
+ protected static void scoreEntryPoints(
+ KnnCollector results,
+ RandomVectorScorer scorer,
+ BitSet visited,
+ int[] eps,
+ Bits acceptOrds,
+ NeighborQueue candidates,
+ float[] scores)
+ throws IOException {
+ assert eps != null && eps.length > 0;
+ assert scores != null && scores.length >= eps.length;
+ scorer.bulkScore(eps, scores, eps.length);
+ results.incVisitedCount(eps.length);
+ for (int i = 0; i < eps.length; i++) {
+ float score = scores[i];
+ int ep = eps[i];
+ visited.set(ep);
+ candidates.add(ep, score);
+ if (acceptOrds == null || acceptOrds.get(ep)) {
+ results.collect(ep, score);
+ }
+ }
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/FilteredHnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/FilteredHnswGraphSearcher.java
index 10e5725..a6a2814 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/FilteredHnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/FilteredHnswGraphSearcher.java
@@ -28,8 +28,9 @@
/**
* Searches an HNSW graph to find nearest neighbors to a query vector. This particular
- * implementation is optimized for a filtered search, inspired by the ACORN-1 algorithm.
- * https://arxiv.org/abs/2403.04871 However, this implementation is augmented in some ways, mainly:
+ * implementation is optimized for a filtered search, inspired by the <a
+ * href="https://arxiv.org/abs/2403.04871">ACORN-1 algorithm</a>. However, this implementation is
+ * augmented in some ways, mainly:
*
* <ul>
* <li>It dynamically determines when the optimized filter step should occur based on some
@@ -114,18 +115,15 @@
prepareScratchState();
- for (int ep : eps) {
- if (visited.getAndSet(ep) == false) {
- if (results.earlyTerminated()) {
- return;
- }
- float score = scorer.score(ep);
- results.incVisitedCount(1);
- candidates.add(ep, score);
- if (acceptOrds.get(ep)) {
- results.collect(ep, score);
- }
- }
+ if (bulkScores == null || bulkScores.length < eps.length) {
+ bulkScores = new float[eps.length];
+ }
+ if (results.earlyTerminated()) {
+ return;
+ }
+ scoreEntryPoints(results, scorer, visited, eps, acceptOrds, candidates, bulkScores);
+ if (results.earlyTerminated()) {
+ return;
}
// Collect the vectors to score and potentially add as candidates
IntArrayQueue toScore = new IntArrayQueue(graph.maxConn() * 2 * maxExplorationMultiplier);
@@ -190,17 +188,29 @@
}
}
// Score the vectors and add them to the candidate list
- int toScoreOrd;
- while ((toScoreOrd = toScore.poll()) != NO_MORE_DOCS) {
- float friendSimilarity = scorer.score(toScoreOrd);
- results.incVisitedCount(1);
- if (friendSimilarity > minAcceptedSimilarity) {
- candidates.add(toScoreOrd, friendSimilarity);
- if (results.collect(toScoreOrd, friendSimilarity)) {
- minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());
+ if (bulkScores == null || bulkScores.length < toScore.count()) {
+ bulkScores = new float[toScore.count()];
+ }
+ assert toScore.upto == 0;
+ float maxScore =
+ toScore.count() > 0
+ ? scorer.bulkScore(toScore.nodes, bulkScores, toScore.size)
+ : Float.NEGATIVE_INFINITY;
+ results.incVisitedCount(toScore.count());
+ if (maxScore > minAcceptedSimilarity) {
+ for (int i = 0; i < toScore.count(); i++) {
+ int idx = i + toScore.upto;
+ float friendSimilarity = bulkScores[idx];
+ if (friendSimilarity > minAcceptedSimilarity) {
+ int ord = toScore.nodes[idx];
+ candidates.add(ord, friendSimilarity);
+ if (results.collect(ord, friendSimilarity)) {
+ minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());
+ }
}
}
}
+ toScore.upto = toScore.size; // all scored
if (results.getSearchStrategy() != null) {
results.getSearchStrategy().nextVectorsBlock();
}
@@ -213,7 +223,7 @@
}
private static class IntArrayQueue {
- private int[] nodes;
+ private final int[] nodes;
private int upto;
private int size;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 4e94203..1876fe6 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -226,7 +226,7 @@
return new int[] {currentEp};
}
int size = getGraphSize(graph);
- prepareScratchState(size);
+ prepareScratchState(size, graph.maxConn() * 2);
float currentScore = scorer.score(currentEp);
collector.incVisitedCount(1);
boolean foundBetter;
@@ -238,6 +238,7 @@
foundBetter = false;
graphSeek(graph, level, currentEp);
int friendOrd;
+ int numNodes = 0;
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
@@ -246,12 +247,21 @@
if (collector.earlyTerminated()) {
return new int[] {UNK_EP};
}
- float friendSimilarity = scorer.score(friendOrd);
- collector.incVisitedCount(1);
- if (friendSimilarity > currentScore) {
- currentScore = friendSimilarity;
- currentEp = friendOrd;
- foundBetter = true;
+ bulkNodes[numNodes++] = friendOrd;
+ }
+ float maxScore =
+ numNodes > 0
+ ? scorer.bulkScore(bulkNodes, bulkScores, numNodes)
+ : Float.NEGATIVE_INFINITY;
+ collector.incVisitedCount(numNodes);
+ if (maxScore > currentScore) {
+ for (int i = 0; i < numNodes; i++) {
+ float score = bulkScores[i];
+ if (score > currentScore) {
+ currentScore = score;
+ currentEp = bulkNodes[i];
+ foundBetter = true;
+ }
}
}
}
@@ -277,25 +287,16 @@
int size = getGraphSize(graph);
- prepareScratchState(size);
-
- if (bulkNodes == null || bulkNodes.length < graph.maxConn() * 2) {
- bulkNodes = new int[graph.maxConn() * 2];
- bulkScores = new float[graph.maxConn() * 2];
+ prepareScratchState(size, graph.maxConn() * 2);
+ if (bulkScores == null || bulkScores.length < eps.length) {
+ bulkScores = new float[eps.length];
}
-
- for (int ep : eps) {
- if (visited.getAndSet(ep) == false) {
- if (results.earlyTerminated()) {
- break;
- }
- float score = scorer.score(ep);
- results.incVisitedCount(1);
- candidates.add(ep, score);
- if (acceptOrds == null || acceptOrds.get(ep)) {
- results.collect(ep, score);
- }
- }
+ if (results.earlyTerminated()) {
+ return;
+ }
+ scoreEntryPoints(results, scorer, visited, eps, acceptOrds, candidates, bulkScores);
+ if (results.earlyTerminated()) {
+ return;
}
// A bound that holds the minimum similarity to the query vector that a candidate vector must
@@ -335,7 +336,7 @@
bulkNodes[numNodes++] = friendOrd;
}
- numNodes = (int) Math.min((long) numNodes, results.visitLimit() - results.visitedCount());
+ numNodes = (int) Math.min(numNodes, results.visitLimit() - results.visitedCount());
results.incVisitedCount(numNodes);
if (numNodes > 0
&& scorer.bulkScore(bulkNodes, bulkScores, numNodes)
@@ -365,12 +366,16 @@
}
}
- private void prepareScratchState(int capacity) {
+ private void prepareScratchState(int capacity, int bulkScoreSize) {
candidates.clear();
if (visited.length() < capacity) {
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
}
visited.clear();
+ if (bulkNodes == null || bulkNodes.length < bulkScoreSize) {
+ bulkNodes = new int[bulkScoreSize];
+ bulkScores = new float[bulkScoreSize];
+ }
}
/**
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
index ad257cd..dacfeea 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
@@ -1503,10 +1503,13 @@
visitedLimit);
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation());
int size = Lucene99HnswVectorsReader.EXHAUSTIVE_BULK_SCORE_ORDS;
+ // visit limit is a "best effort" limit given our bulk scoring logic; assert that the
+ // visited count is at most the limit rounded up to a whole number of bulk-scoring
+ // blocks
assertTrue(
- visitedLimit == results.totalHits.value()
- || ((visitedLimit + size - 1) / size) * ((long) size)
- == results.totalHits.value());
+ results.totalHits.value() == visitedLimit
+ || results.totalHits.value()
+ <= ((visitedLimit + size - 1) / size) * ((long) size));
// check the limit is not hit when it clearly exceeds the number of vectors
k = vectorValues.size();