Prevent extra similarity computation for single-level graphs (#12866)
### Description
[`#findBestEntryPoint`](https://github.com/apache/lucene/blob/4bc7850465dfac9dc0638d9ee782007883869ffe/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L151) is used to determine the entry point for the last level of HNSW search
It finds the single best-scoring node from [all upper levels](https://github.com/apache/lucene/blob/4bc7850465dfac9dc0638d9ee782007883869ffe/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L159) - but performs an [unnecessary computation](https://github.com/apache/lucene/blob/4bc7850465dfac9dc0638d9ee782007883869ffe/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L157) (along with [recording one visited node](https://github.com/apache/lucene/blob/4bc7850465dfac9dc0638d9ee782007883869ffe/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java#L154)) when the graph just has 1 level (so the entry node is just the overall graph's entry node)
Also added a test to demonstrate this (fails without the changes in PR) -- where we visit `graph.size() + 1` nodes when the `topK` is high (should be a maximum of `graph.size()`)
---------
Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index eb3a748..4d24255 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -140,6 +140,33 @@
* GITHUB#9049: Fixing bug in UnescapedCharSequence#toStringEscaped() (Jakub Slowinski)
+======================== Lucene 9.10.0 =======================
+
+API Changes
+---------------------
+(No changes)
+
+New Features
+---------------------
+(No changes)
+
+Improvements
+---------------------
+(No changes)
+
+Optimizations
+---------------------
+(No changes)
+
+Bug Fixes
+---------------------
+
+* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
+
+Other
+---------------------
+(No changes)
+
======================== Lucene 9.9.0 =======================
API Changes
diff --git a/lucene/core/src/java/org/apache/lucene/util/Version.java b/lucene/core/src/java/org/apache/lucene/util/Version.java
index f7c0ad4..8c12488 100644
--- a/lucene/core/src/java/org/apache/lucene/util/Version.java
+++ b/lucene/core/src/java/org/apache/lucene/util/Version.java
@@ -112,11 +112,18 @@
/**
* Match settings and bugs in Lucene's 9.9.0 release.
*
- * @deprecated Use latest
+ * @deprecated (9.10.0) Use latest
*/
@Deprecated public static final Version LUCENE_9_9_0 = new Version(9, 9, 0);
/**
+ * Match settings and bugs in Lucene's 9.10.0 release.
+ *
+ * @deprecated Use latest
+ */
+ @Deprecated public static final Version LUCENE_9_10_0 = new Version(9, 10, 0);
+
+ /**
* Match settings and bugs in Lucene's 10.0.0 release.
*
* <p>Use this to get the latest & greatest settings, bug fixes, etc, for Lucene.
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 0135fc5..2aa5389 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -100,19 +100,10 @@
HnswGraphSearcher graphSearcher,
Bits acceptOrds)
throws IOException {
- int initialEp = graph.entryNode();
- if (initialEp == -1) {
- return;
+ int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
+ if (ep != -1) {
+ graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
}
- int[] epAndVisited = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector.visitLimit());
- int numVisited = epAndVisited[1];
- int ep = epAndVisited[0];
- if (ep == -1) {
- knnCollector.incVisitedCount(numVisited);
- return;
- }
- knnCollector.incVisitedCount(numVisited);
- graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);
}
/**
@@ -143,18 +134,21 @@
*
* @param scorer the scorer to compare the query with the nodes
* @param graph the HNSWGraph
- * @param visitLimit How many vectors are allowed to be visited
- * @return An integer array whose first element is the best entry point, and second is the number
- * of candidates visited. Entry point of `-1` indicates visitation limit exceed
+ * @param collector the knn result collector
+ * @return the best entry point, `-1` indicates graph entry node not set, or visitation limit
+ * exceeded
* @throws IOException When accessing the vector fails
*/
- private int[] findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, long visitLimit)
+ private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector)
throws IOException {
- int size = getGraphSize(graph);
- int visitedCount = 1;
- prepareScratchState(size);
int currentEp = graph.entryNode();
+ if (currentEp == -1 || graph.numLevels() == 1) {
+ return currentEp;
+ }
+ int size = getGraphSize(graph);
+ prepareScratchState(size);
float currentScore = scorer.score(currentEp);
+ collector.incVisitedCount(1);
boolean foundBetter;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
foundBetter = true;
@@ -169,11 +163,11 @@
if (visited.getAndSet(friendOrd)) {
continue;
}
- if (visitedCount >= visitLimit) {
- return new int[] {-1, visitedCount};
+ if (collector.earlyTerminated()) {
+ return -1;
}
float friendSimilarity = scorer.score(friendOrd);
- visitedCount++;
+ collector.incVisitedCount(1);
if (friendSimilarity > currentScore) {
currentScore = friendSimilarity;
currentEp = friendOrd;
@@ -182,7 +176,7 @@
}
}
}
- return new int[] {currentEp, visitedCount};
+ return collector.earlyTerminated() ? -1 : currentEp;
}
/**
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index b943d3a..0cc1007 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -70,6 +70,7 @@
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -1026,6 +1027,37 @@
}
}
+ public void testAllNodesVisitedInSingleLevel() throws IOException {
+ int size = atLeast(100);
+ int dim = atLeast(50);
+
+ // Search for a large number of results
+ int topK = size - 1;
+
+ AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
+ HnswGraph graph =
+ HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
+ .build(size);
+
+ HnswGraph singleLevelGraph =
+ new DelegateHnswGraph(graph) {
+ @Override
+ public int numLevels() {
+ // Only retain the last level
+ return 1;
+ }
+ };
+
+ AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
+ RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
+
+ KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
+ HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
+
+ // Check that we visit all nodes
+ assertEquals(graph.size(), collector.visitedCount());
+ }
+
private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
@@ -1297,4 +1329,42 @@
return sb.toString();
}
+
+ private static class DelegateHnswGraph extends HnswGraph {
+ final HnswGraph delegate;
+
+ DelegateHnswGraph(HnswGraph delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void seek(int level, int target) throws IOException {
+ delegate.seek(level, target);
+ }
+
+ @Override
+ public int size() {
+ return delegate.size();
+ }
+
+ @Override
+ public int nextNeighbor() throws IOException {
+ return delegate.nextNeighbor();
+ }
+
+ @Override
+ public int numLevels() throws IOException {
+ return delegate.numLevels();
+ }
+
+ @Override
+ public int entryNode() throws IOException {
+ return delegate.entryNode();
+ }
+
+ @Override
+ public NodesIterator getNodesOnLevel(int level) throws IOException {
+ return delegate.getNodesOnLevel(level);
+ }
+ }
}