Merge branch 'speed_float_cosine'
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 13840a9..3f7d9fe 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -167,6 +167,12 @@
* GITHUB#12685: Lucene now records if documents have been indexed as blocks in SegmentInfo. This is recorded on a per
segment basis and maintained across merges. The property is exposed via LeafReaderMetadata. (Simon Willnauer)
+* GITHUB#12582: Add int8 scalar quantization to the HNSW vector format. This optionally allows for more compact lossy
+  storage for the vectors, reducing the memory required for fast HNSW search by roughly 75%. (Ben Trent)
+
+* GITHUB#12660: HNSW graphs can now be merged using multiple threads. Configurable in Lucene99HnswVectorsFormat.
+  (Patrick Zhai)
+
Improvements
---------------------
* GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions
@@ -236,6 +242,11 @@
* GITHUB#12702: Disable suffix sharing for block tree index, making writing the terms dictionary index faster
  and less RAM hungry, while making the index a bit smaller (~1.X% for the terms index file on wikipedia). (Guo Feng, Mike McCandless)
+* GITHUB#12726: Return the same input vector if it is a unit vector in VectorUtil#l2normalize. (Shubham Chaudhary)
+
+* GITHUB#12719: Top-level conjunctions that are not sorted by score now have a
+ specialized bulk scorer. (Adrien Grand)
+
Changes in runtime behavior
---------------------
@@ -256,6 +267,8 @@
* GITHUB#12682: Scorer should sum up scores into a double. (Shubham Chaudhary)
+* GITHUB#12727: Ensure negative scores are not returned by vector similarity functions. (Ben Trent)
+
Build
---------------------
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
index 7ac2470..415d1ca 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
@@ -476,9 +476,9 @@
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
- HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
- hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- graph = hnswGraphBuilder.build(docsWithField.cardinality());
+ graph =
+ merger.merge(
+ mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java
index be7f84f..038b75e 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java
@@ -18,6 +18,7 @@
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
+import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@@ -151,6 +152,9 @@
*/
public static final int DEFAULT_BEAM_WIDTH = 100;
+  /** Default number of merge workers: merge on a single thread. */
+ public static final int DEFAULT_NUM_MERGE_WORKER = 1;
+
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
/**
@@ -169,20 +173,46 @@
/** Should this codec scalar quantize float32 vectors and use this format */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;
+ private final int numMergeWorkers;
+ private final ExecutorService mergeExec;
+
/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
}
+ public Lucene99HnswVectorsFormat(
+ int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
+ this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
+ */
+ public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
+ this(maxConn, beamWidth, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters and scalar quantization.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
+   * @param numMergeWorkers number of workers (threads) that will be used when merging. If larger
+   *     than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+   * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
+   *     generated by this format to perform merges
*/
public Lucene99HnswVectorsFormat(
- int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
+ int maxConn,
+ int beamWidth,
+ Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
+ int numMergeWorkers,
+ ExecutorService mergeExec) {
super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
@@ -198,14 +228,25 @@
+ "; beamWidth="
+ beamWidth);
}
+ if (numMergeWorkers > 1 && mergeExec == null) {
+ throw new IllegalArgumentException(
+ "No executor service passed in when " + numMergeWorkers + " merge workers are requested");
+ }
+ if (numMergeWorkers == 1 && mergeExec != null) {
+ throw new IllegalArgumentException(
+          "No executor service is needed as merging will be done on a single thread");
+ }
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
+ this.numMergeWorkers = numMergeWorkers;
+ this.mergeExec = mergeExec;
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
- return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
+ return new Lucene99HnswVectorsWriter(
+ state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec);
}
@Override
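A minimal usage sketch of the new five-argument constructor (illustrative only, not part of this
patch; the pool size and graph parameters are arbitrary). Note that, per the close() change further
below, the writer calls mergeExec.shutdownNow(), so the executor should be dedicated to this
format instance:

    // assumes java.util.concurrent.* and the Lucene99 classes from this patch are imported
    KnnVectorsFormat hnswFormatWithConcurrentMerge() {
      ExecutorService mergeExec = Executors.newFixedThreadPool(4);
      return new Lucene99HnswVectorsFormat(
          16 /* maxConn */, 100 /* beamWidth */, null /* no quantization */, 4, mergeExec);
    }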
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
index e3fd543..5515de2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
@@ -28,6 +28,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
@@ -52,9 +53,11 @@
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
+import org.apache.lucene.util.hnsw.HnswGraphMerger;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@@ -75,6 +78,8 @@
private final int M;
private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
+ private final int numMergeWorkers;
+ private final ExecutorService mergeExec;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
@@ -83,10 +88,14 @@
SegmentWriteState state,
int M,
int beamWidth,
- Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
+ Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat,
+ int numMergeWorkers,
+ ExecutorService mergeExec)
throws IOException {
this.M = M;
this.beamWidth = beamWidth;
+ this.numMergeWorkers = numMergeWorkers;
+ this.mergeExec = mergeExec;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@@ -383,7 +392,7 @@
int node = nodesOnLevel0.nextInt();
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
- reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
+ reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
@@ -400,7 +409,7 @@
for (int node : newNodes) {
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
- reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
+ reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[level][nodeOffsetIndex++] =
Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
@@ -442,7 +451,7 @@
};
}
- private void reconstructAndWriteNeigbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
+ private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
throws IOException {
int size = neighbors.size();
vectorIndex.writeVInt(size);
@@ -557,6 +566,12 @@
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+            // just return the inner supplier, since only this outer copy needs to be closed
+ return innerScoreSupplier.copy();
+ }
};
} else {
// No need to use temporary file as we don't have to re-open for reading
@@ -579,8 +594,7 @@
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
- IncrementalHnswGraphMerger merger =
- new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
+ HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
@@ -592,9 +606,9 @@
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
- HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
- hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- graph = hnswGraphBuilder.build(docsWithField.cardinality());
+ graph =
+ merger.merge(
+ mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -675,6 +689,15 @@
return sortedNodes;
}
+ private HnswGraphMerger createGraphMerger(
+ FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
+ if (mergeExec != null) {
+ return new ConcurrentHnswMerger(
+ fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
+ }
+ return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
+ }
+
private void writeMeta(
boolean isQuantized,
FieldInfo field,
@@ -819,6 +842,9 @@
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
+ if (mergeExec != null) {
+ mergeExec.shutdownNow();
+ }
}
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java
index c05141c..c048581 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java
@@ -52,6 +52,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
* Writes quantized vector values and metadata to index segments.
@@ -233,7 +234,8 @@
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
- writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
+ DocsWithFieldSet docsWithField =
+ writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
CodecUtil.writeFooter(tempQuantizedVectorData);
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
@@ -253,7 +255,9 @@
fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState,
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
- fieldInfo.getVectorDimension(), byteVectorValues.size(), quantizationDataInput)));
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ quantizationDataInput)));
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(quantizationDataInput);
@@ -762,6 +766,11 @@
}
@Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return supplier.copy();
+ }
+
+ @Override
public void close() throws IOException {
onClose.close();
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorerSupplier.java
index bb9eacf..2d01bfa 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorerSupplier.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorerSupplier.java
@@ -39,6 +39,12 @@
this.values = values;
}
+ private ScalarQuantizedRandomVectorScorerSupplier(
+ ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) {
+ this.similarity = similarity;
+ this.values = values;
+ }
+
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
@@ -46,4 +52,9 @@
final float queryOffset = values.getScoreCorrectionConstant();
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
}
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy());
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
index ae0633e..23bd0fd 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
@@ -52,7 +52,7 @@
DOT_PRODUCT {
@Override
public float compare(float[] v1, float[] v2) {
- return (1 + dotProduct(v1, v2)) / 2;
+ return Math.max((1 + dotProduct(v1, v2)) / 2, 0);
}
@Override
@@ -70,7 +70,7 @@
COSINE {
@Override
public float compare(float[] v1, float[] v2) {
- return (1 + cosine(v1, v2)) / 2;
+ return Math.max((1 + cosine(v1, v2)) / 2, 0);
}
@Override
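An illustrative case (hypothetical values, not from this patch) of why the clamp matters: for
vectors that are not unit length, the raw dot product can fall below -1, in which case compare()
used to return a negative score:

    float[] a = {2f, 0f};
    float[] b = {-2f, 0f};
    // dotProduct(a, b) == -4, so (1 + -4) / 2 == -1.5f before this change;
    // with Math.max(..., 0) the returned score is now 0.
    float score = VectorSimilarityFunction.DOT_PRODUCT.compare(a, b);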
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
index aa4dd70..d550076 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
@@ -317,6 +317,12 @@
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
return new BlockMaxConjunctionBulkScorer(context.reader().maxDoc(), requiredScoring);
}
+ if (scoreMode != ScoreMode.TOP_SCORES
+ && requiredScoring.size() + requiredNoScoring.size() >= 2
+ && requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
+ && requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
+ return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
+ }
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
requiredScoring =
Collections.singletonList(new BlockMaxConjunctionScorer(this, requiredScoring));
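For illustration (a hypothetical query, not part of the patch), a top-level conjunction that takes
the new branch when collected under a score mode other than TOP_SCORES: two scoring MUST clauses
plus one FILTER clause, none of which produce two-phase iterators:

    Query q =
        new BooleanQuery.Builder()
            .add(new TermQuery(new Term("body", "fast")), BooleanClause.Occur.MUST)
            .add(new TermQuery(new Term("body", "search")), BooleanClause.Occur.MUST)
            .add(new TermQuery(new Term("lang", "en")), BooleanClause.Occur.FILTER)
            .build();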
diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionBulkScorer.java
new file mode 100644
index 0000000..04ad472
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionBulkScorer.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import org.apache.lucene.util.Bits;
+
+/**
+ * BulkScorer implementation of {@link ConjunctionScorer}. For simplicity, it focuses on scorers
+ * that produce regular {@link DocIdSetIterator}s and not {@link TwoPhaseIterator}s.
+ */
+final class ConjunctionBulkScorer extends BulkScorer {
+
+ private final Scorer[] scoringScorers;
+ private final DocIdSetIterator lead1, lead2;
+ private final List<DocIdSetIterator> others;
+ private final Scorable scorable;
+
+ ConjunctionBulkScorer(List<Scorer> requiredScoring, List<Scorer> requiredNoScoring)
+ throws IOException {
+ final int numClauses = requiredScoring.size() + requiredNoScoring.size();
+ if (numClauses <= 1) {
+ throw new IllegalArgumentException("Expected 2 or more clauses, got " + numClauses);
+ }
+ List<Scorer> allScorers = new ArrayList<>();
+ allScorers.addAll(requiredScoring);
+ allScorers.addAll(requiredNoScoring);
+
+ this.scoringScorers = requiredScoring.toArray(Scorer[]::new);
+ List<DocIdSetIterator> iterators = new ArrayList<>();
+ for (Scorer scorer : allScorers) {
+ iterators.add(scorer.iterator());
+ }
+ Collections.sort(iterators, Comparator.comparingLong(DocIdSetIterator::cost));
+ lead1 = iterators.get(0);
+ lead2 = iterators.get(1);
+ others = List.copyOf(iterators.subList(2, iterators.size()));
+ scorable =
+ new Scorable() {
+ @Override
+ public float score() throws IOException {
+ double score = 0;
+ for (Scorer scorer : scoringScorers) {
+ score += scorer.score();
+ }
+ return (float) score;
+ }
+
+ @Override
+ public Collection<ChildScorable> getChildren() throws IOException {
+ ArrayList<ChildScorable> children = new ArrayList<>();
+ for (Scorer scorer : allScorers) {
+ children.add(new ChildScorable(scorer, "MUST"));
+ }
+ return children;
+ }
+ };
+ }
+
+ @Override
+ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
+ assert lead1.docID() >= lead2.docID();
+
+ if (lead1.docID() < min) {
+ lead1.advance(min);
+ }
+
+ if (lead1.docID() >= max) {
+ return lead1.docID();
+ }
+
+ collector.setScorer(scorable);
+
+ List<DocIdSetIterator> otherIterators = this.others;
+ DocIdSetIterator collectorIterator = collector.competitiveIterator();
+ if (collectorIterator != null) {
+ otherIterators = new ArrayList<>(otherIterators);
+ otherIterators.add(collectorIterator);
+ }
+
+ final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new);
+
+    // In the main for loop, we want to be able to rely on the invariant that lead1.docID() >
+    // lead2.docID(). However it's possible that these two are equal on the first document in a
+    // scoring window. So we treat this case separately here.
+ if (lead1.docID() == lead2.docID()) {
+ final int doc = lead1.docID();
+ if (acceptDocs == null || acceptDocs.get(doc)) {
+ boolean match = true;
+ for (DocIdSetIterator it : others) {
+ if (it.docID() < doc) {
+ int next = it.advance(doc);
+ if (next != doc) {
+ lead1.advance(next);
+ match = false;
+ break;
+ }
+ }
+ assert it.docID() == doc;
+ }
+
+ if (match) {
+ collector.collect(doc);
+ lead1.nextDoc();
+ }
+ } else {
+ lead1.nextDoc();
+ }
+ }
+
+ advanceHead:
+ for (int doc = lead1.docID(); doc < max; ) {
+ assert lead2.docID() < doc;
+
+ if (acceptDocs != null && acceptDocs.get(doc) == false) {
+ doc = lead1.nextDoc();
+ continue;
+ }
+
+ // We maintain the invariant that lead2.docID() < lead1.docID() so that we don't need to check
+ // if lead2 is already on the same doc as lead1 here.
+ int next2 = lead2.advance(doc);
+ if (next2 != doc) {
+ doc = lead1.advance(next2);
+ if (doc != next2) {
+ continue;
+ } else if (doc >= max) {
+ break;
+ } else if (acceptDocs != null && acceptDocs.get(doc) == false) {
+ doc = lead1.nextDoc();
+ continue;
+ }
+ }
+ assert lead2.docID() == doc;
+
+ for (DocIdSetIterator it : others) {
+ if (it.docID() < doc) {
+ int next = it.advance(doc);
+ if (next != doc) {
+ doc = lead1.advance(next);
+ continue advanceHead;
+ }
+ }
+ assert it.docID() == doc;
+ }
+
+ collector.collect(doc);
+ doc = lead1.nextDoc();
+ }
+
+ return lead1.docID();
+ }
+
+ @Override
+ public long cost() {
+ return lead1.cost();
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index a978e2d..ef52e605 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -106,18 +106,21 @@
* @throws IllegalArgumentException when the vector is all zero and throwOnZero is true
*/
public static float[] l2normalize(float[] v, boolean throwOnZero) {
- double squareSum = IMPL.dotProduct(v, v);
- int dim = v.length;
- if (squareSum == 0) {
+ double l1norm = IMPL.dotProduct(v, v);
+ if (l1norm == 0) {
if (throwOnZero) {
throw new IllegalArgumentException("Cannot normalize a zero-length vector");
} else {
return v;
}
}
- double length = Math.sqrt(squareSum);
+ if (Math.abs(l1norm - 1.0d) <= 1e-5) {
+ return v;
+ }
+ int dim = v.length;
+ double l2norm = Math.sqrt(l1norm);
for (int i = 0; i < dim; i++) {
- v[i] /= length;
+ v[i] /= (float) l2norm;
}
return v;
}
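Illustrative effect of the new unit-vector short-circuit (example values chosen so the squared sum
falls within the 1e-5 tolerance of 1.0):

    float[] v = {0.6f, 0.8f}; // 0.36 + 0.64 is within 1e-5 of 1.0, so v is already unit length
    // the components are now returned untouched instead of each being divided by ~1.0,
    // avoiding needless float rounding drift
    float[] normalized = VectorUtil.l2normalize(v);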
diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/StringsToAutomaton.java b/lucene/core/src/java/org/apache/lucene/util/automaton/StringsToAutomaton.java
index 6c66dc6..3cfe945 100644
--- a/lucene/core/src/java/org/apache/lucene/util/automaton/StringsToAutomaton.java
+++ b/lucene/core/src/java/org/apache/lucene/util/automaton/StringsToAutomaton.java
@@ -269,7 +269,7 @@
throw new IllegalArgumentException(
"This builder doesn't allow terms that are larger than "
+ Automata.MAX_STRING_UNION_TERM_LENGTH
- + " characters, got "
+ + " UTF-8 bytes, got "
+ current);
}
assert stateRegistry != null : "Automaton already built.";
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java
index 9ef4fa0..1490624 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java
@@ -22,6 +22,9 @@
/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
+ *
+ * <p>NOTE: the {@link RandomVectorScorerSupplier} returned by {@link #copy()} is not necessarily
+ * closeable
*/
public interface CloseableRandomVectorScorerSupplier
extends Closeable, RandomVectorScorerSupplier {}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
new file mode 100644
index 0000000..2253e73
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.hnsw;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import org.apache.lucene.codecs.HnswGraphProvider;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.BitSet;
+import org.apache.lucene.util.FixedBitSet;
+
+/** This merger merges graphs concurrently, using {@link HnswConcurrentMergeBuilder} */
+public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
+
+ private final ExecutorService exec;
+ private final int numWorker;
+
+  /**
+   * @param fieldInfo FieldInfo for the field being merged
+   * @param scorerSupplier supplier of scorers over the merged vectors
+   * @param M max number of connections per node
+   * @param beamWidth size of the candidate queue maintained during graph construction
+   * @param exec executor service that runs the merge workers
+   * @param numWorker number of workers to use for the merge
+   */
+ public ConcurrentHnswMerger(
+ FieldInfo fieldInfo,
+ RandomVectorScorerSupplier scorerSupplier,
+ int M,
+ int beamWidth,
+ ExecutorService exec,
+ int numWorker) {
+ super(fieldInfo, scorerSupplier, M, beamWidth);
+ this.exec = exec;
+ this.numWorker = numWorker;
+ }
+
+ @Override
+ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
+ throws IOException {
+ if (initReader == null) {
+ return new HnswConcurrentMergeBuilder(
+ exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
+ }
+
+ HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
+ BitSet initializedNodes = new FixedBitSet(maxOrd);
+ int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
+
+ return new HnswConcurrentMergeBuilder(
+ exec,
+ numWorker,
+ scorerSupplier,
+ M,
+ beamWidth,
+ InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd),
+ initializedNodes);
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java
new file mode 100644
index 0000000..5473856
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswBuilder.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import java.io.IOException;
+import org.apache.lucene.util.InfoStream;
+
+/**
+ * Interface for builders that build an {@link OnHeapHnswGraph}
+ *
+ * @lucene.experimental
+ */
+public interface HnswBuilder {
+
+ /**
+ * Adds all nodes to the graph up to the provided {@code maxOrd}.
+ *
+   * @param maxOrd The maximum ordinal (exclusive) of the nodes to be added.
+ */
+ OnHeapHnswGraph build(int maxOrd) throws IOException;
+
+  /** Inserts a node with a vector value into the graph */
+ void addGraphNode(int node) throws IOException;
+
+ /** Set info-stream to output debugging information */
+ void setInfoStream(InfoStream infoStream);
+
+ OnHeapHnswGraph getGraph();
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
new file mode 100644
index 0000000..27e555a
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.hnsw;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.hnsw.HnswGraphBuilder.HNSW_COMPONENT;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.lucene.util.BitSet;
+import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.InfoStream;
+import org.apache.lucene.util.ThreadInterruptedException;
+
+/**
+ * A graph builder that manages multiple workers. It only supports adding the whole graph all at
+ * once: it submits one task per worker to the executor, and the workers claim work in batches.
+ */
+public class HnswConcurrentMergeBuilder implements HnswBuilder {
+
+ private static final int DEFAULT_BATCH_SIZE =
+      2048; // number of vectors a worker handles sequentially in one batch
+
+ private final ExecutorService exec;
+ private final ConcurrentMergeWorker[] workers;
+ private InfoStream infoStream = InfoStream.getDefault();
+
+ public HnswConcurrentMergeBuilder(
+ ExecutorService exec,
+ int numWorker,
+ RandomVectorScorerSupplier scorerSupplier,
+ int M,
+ int beamWidth,
+ OnHeapHnswGraph hnsw,
+ BitSet initializedNodes)
+ throws IOException {
+ this.exec = exec;
+ AtomicInteger workProgress = new AtomicInteger(0);
+ workers = new ConcurrentMergeWorker[numWorker];
+ for (int i = 0; i < numWorker; i++) {
+ workers[i] =
+ new ConcurrentMergeWorker(
+ scorerSupplier.copy(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed,
+ hnsw,
+ initializedNodes,
+ workProgress);
+ }
+ }
+
+ @Override
+ public OnHeapHnswGraph build(int maxOrd) throws IOException {
+ if (infoStream.isEnabled(HNSW_COMPONENT)) {
+ infoStream.message(
+ HNSW_COMPONENT,
+ "build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
+ }
+ List<Future<?>> futures = new ArrayList<>();
+ for (int i = 0; i < workers.length; i++) {
+ int finalI = i;
+ futures.add(
+ exec.submit(
+ () -> {
+ try {
+ workers[finalI].run(maxOrd);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }));
+ }
+ Throwable exc = null;
+ for (Future<?> future : futures) {
+ try {
+ future.get();
+ } catch (InterruptedException e) {
+ var newException = new ThreadInterruptedException(e);
+ if (exc == null) {
+ exc = newException;
+ } else {
+ exc.addSuppressed(newException);
+ }
+ } catch (ExecutionException e) {
+ if (exc == null) {
+ exc = e.getCause();
+ } else {
+ exc.addSuppressed(e.getCause());
+ }
+ }
+ }
+ if (exc != null) {
+      // The error handling was copied from TaskExecutor. Should we just use TaskExecutor instead?
+ throw IOUtils.rethrowAlways(exc);
+ }
+ return workers[0].getGraph();
+ }
+
+ @Override
+ public void addGraphNode(int node) throws IOException {
+ throw new UnsupportedOperationException("This builder is for merge only");
+ }
+
+ @Override
+ public void setInfoStream(InfoStream infoStream) {
+ this.infoStream = infoStream;
+ for (HnswBuilder worker : workers) {
+ worker.setInfoStream(infoStream);
+ }
+ }
+
+ @Override
+ public OnHeapHnswGraph getGraph() {
+ return workers[0].getGraph();
+ }
+
+ /* test only for now */
+ void setBatchSize(int newSize) {
+ for (ConcurrentMergeWorker worker : workers) {
+ worker.batchSize = newSize;
+ }
+ }
+
+ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {
+
+ /**
+     * A common AtomicInteger shared among all workers, used to track the next vector to be added
+     * to the graph.
+ */
+ private final AtomicInteger workProgress;
+
+ private final BitSet initializedNodes;
+ private int batchSize = DEFAULT_BATCH_SIZE;
+
+ private ConcurrentMergeWorker(
+ RandomVectorScorerSupplier scorerSupplier,
+ int M,
+ int beamWidth,
+ long seed,
+ OnHeapHnswGraph hnsw,
+ BitSet initializedNodes,
+ AtomicInteger workProgress)
+ throws IOException {
+ super(
+ scorerSupplier,
+ M,
+ beamWidth,
+ seed,
+ hnsw,
+ new MergeSearcher(
+ new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.maxNodeId() + 1)));
+ this.workProgress = workProgress;
+ this.initializedNodes = initializedNodes;
+ }
+
+ /**
+     * This method first tries to "reserve" a batch of work by calling {@link #getStartPos(int)},
+     * and then calls {@link #addVectors(int, int)} to actually add the nodes to the graph. This
+     * dynamically distributes the work across the workers and tries to make all of them finish
+     * at around the same time.
+ */
+ private void run(int maxOrd) throws IOException {
+ int start = getStartPos(maxOrd);
+ int end;
+ while (start != -1) {
+ end = Math.min(maxOrd, start + batchSize);
+ addVectors(start, end);
+ start = getStartPos(maxOrd);
+ }
+ }
+
+    /** Reserve work by atomically incrementing {@link #workProgress} */
+ private int getStartPos(int maxOrd) {
+ int start = workProgress.getAndAdd(batchSize);
+ if (start < maxOrd) {
+ return start;
+ } else {
+ return -1;
+ }
+ }
+
+ @Override
+ public void addGraphNode(int node) throws IOException {
+ if (initializedNodes != null && initializedNodes.get(node)) {
+ return;
+ }
+ super.addGraphNode(node);
+ }
+ }
+
+ /**
+   * This searcher obtains the read lock and copies the neighbor array when seeking the graph, so
+   * that concurrent modification of the graph does not impact the search
+ */
+ private static class MergeSearcher extends HnswGraphSearcher {
+ private int[] nodeBuffer;
+ private int upto;
+ private int size;
+
+ private MergeSearcher(NeighborQueue candidates, BitSet visited) {
+ super(candidates, visited);
+ }
+
+ @Override
+ void graphSeek(HnswGraph graph, int level, int targetNode) {
+ NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
+ neighborArray.rwlock.readLock().lock();
+ try {
+ if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) {
+ nodeBuffer = new int[neighborArray.size()];
+ }
+ size = neighborArray.size();
+ if (size >= 0) System.arraycopy(neighborArray.node, 0, nodeBuffer, 0, size);
+ } finally {
+ neighborArray.rwlock.readLock().unlock();
+ }
+ upto = -1;
+ }
+
+ @Override
+ int graphNextNeighbor(HnswGraph graph) {
+ if (++upto < size) {
+ return nodeBuffer[upto];
+ }
+ return NO_MORE_DOCS;
+ }
+ }
+}
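The work distribution implemented by run()/getStartPos() above reduces to the following loop (a
simplified sketch; workProgress is the AtomicInteger shared by all workers and maxOrd the total
number of vectors):

    int start;
    while ((start = workProgress.getAndAdd(batchSize)) < maxOrd) {
      int end = Math.min(maxOrd, start + batchSize);
      addVectors(start, end); // build the claimed range [start, end) into the shared graph
    }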
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d1f8c3d..a178106 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -33,7 +33,7 @@
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyper-parameters.
*/
-public class HnswGraphBuilder {
+public class HnswGraphBuilder implements HnswBuilder {
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
@@ -54,7 +54,6 @@
private final int M; // max number of connections on upper layers
private final double ml;
- private final NeighborArray scratch;
private final SplittableRandom random;
private final RandomVectorScorerSupplier scorerSupplier;
@@ -97,6 +96,22 @@
this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
}
+ protected HnswGraphBuilder(
+ RandomVectorScorerSupplier scorerSupplier,
+ int M,
+ int beamWidth,
+ long seed,
+ OnHeapHnswGraph hnsw)
+ throws IOException {
+ this(
+ scorerSupplier,
+ M,
+ beamWidth,
+ seed,
+ hnsw,
+ new HnswGraphSearcher(new NeighborQueue(beamWidth, true), new FixedBitSet(hnsw.size())));
+ }
+
/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
@@ -114,7 +129,8 @@
int M,
int beamWidth,
long seed,
- OnHeapHnswGraph hnsw)
+ OnHeapHnswGraph hnsw,
+ HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
@@ -129,20 +145,12 @@
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
this.hnsw = hnsw;
- this.graphSearcher =
- new HnswGraphSearcher(
- new NeighborQueue(beamWidth, true), new FixedBitSet(this.getGraph().size()));
- // in scratch we store candidates in reverse order: worse candidates are first
- scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
+ this.graphSearcher = graphSearcher;
entryCandidates = new GraphBuilderKnnCollector(1);
beamCandidates = new GraphBuilderKnnCollector(beamWidth);
}
- /**
- * Adds all nodes to the graph up to the provided {@code maxOrd}.
- *
- * @param maxOrd The maximum ordinal of the nodes to be added.
- */
+ @Override
public OnHeapHnswGraph build(int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + maxOrd + " vectors");
@@ -151,18 +159,23 @@
return hnsw;
}
- /** Set info-stream to output debugging information * */
+ @Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
}
+ @Override
public OnHeapHnswGraph getGraph() {
return hnsw;
}
- private void addVectors(int maxOrd) throws IOException {
+  /** Add vectors in the range [minOrd, maxOrd) */
+ protected void addVectors(int minOrd, int maxOrd) throws IOException {
long start = System.nanoTime(), t = start;
- for (int node = 0; node < maxOrd; node++) {
+ if (infoStream.isEnabled(HNSW_COMPONENT)) {
+ infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
+ }
+ for (int node = minOrd; node < maxOrd; node++) {
addGraphNode(node);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
@@ -170,42 +183,98 @@
}
}
- /** Inserts a doc with vector value to the graph */
+ private void addVectors(int maxOrd) throws IOException {
+ addVectors(0, maxOrd);
+ }
+
+ @Override
public void addGraphNode(int node) throws IOException {
+ /*
+    Note: this implementation is thread safe when the graph size is fixed (e.g. when merging).
+    The process of adding a node is roughly:
+    1. Add the node to all levels from top to bottom, but do not connect it to any other node,
+       nor try to promote it to an entry node before the connections are done. (Unless the graph is
+       empty and this is the first node, in which case we set the entry node and return.)
+    2. Search from top to bottom, remembering all the possible neighbours on each level the node
+       is on.
+    3. Add the neighbours to the node from the bottom level up; when adding a neighbour,
+       we always add all the outgoing links first before adding an incoming link, such that
+       when a search visits this node, it can always find a way out.
+    4. If the node's level is less than or equal to the graph level, then we're done here.
+       If the node's level is larger than the graph level, then we need to promote the node
+       to the entry node. If, while we were adding the node to the graph, the entry node changed
+       (which means the graph level changed as well), we need to reinsert the node
+       into the newly introduced levels (repeating steps 2 and 3 for the new levels) and again try
+       to promote the node to entry node.
+ */
RandomVectorScorer scorer = scorerSupplier.scorer(node);
final int nodeLevel = getRandomGraphLevel(ml, random);
- int curMaxLevel = hnsw.numLevels() - 1;
-
- // If entrynode is -1, then this should finish without adding neighbors
- if (hnsw.entryNode() == -1) {
- for (int level = nodeLevel; level >= 0; level--) {
- hnsw.addNode(level, node);
- }
+ // first add nodes to all levels
+ for (int level = nodeLevel; level >= 0; level--) {
+ hnsw.addNode(level, node);
+ }
+ // then promote itself as entry node if entry node is not set
+ if (hnsw.trySetNewEntryNode(node, nodeLevel)) {
return;
}
- int[] eps = new int[] {hnsw.entryNode()};
+ // if the entry node is already set, then we have to do all connections first before we can
+ // promote ourselves as entry node
- // if a node introduces new levels to the graph, add this new node on new levels
- for (int level = nodeLevel; level > curMaxLevel; level--) {
- hnsw.addNode(level, node);
- }
+ int lowestUnsetLevel = 0;
+ int curMaxLevel;
+ do {
+ curMaxLevel = hnsw.numLevels() - 1;
+      // NOTE: the entry node and max level may not be paired, but because we read the level first,
+      // the entry node we read afterwards is guaranteed to exist on curMaxLevel
+ int[] eps = new int[] {hnsw.entryNode()};
- // for levels > nodeLevel search with topk = 1
- GraphBuilderKnnCollector candidates = entryCandidates;
- for (int level = curMaxLevel; level > nodeLevel; level--) {
- candidates.clear();
- graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
- eps = new int[] {candidates.popNode()};
- }
- // for levels <= nodeLevel search with topk = beamWidth, and add connections
- candidates = beamCandidates;
- for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
- candidates.clear();
- graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
- eps = candidates.popUntilNearestKNodes();
- hnsw.addNode(level, node);
- addDiverseNeighbors(level, node, candidates);
- }
+ // we first do the search from top to bottom
+ // for levels > nodeLevel search with topk = 1
+ GraphBuilderKnnCollector candidates = entryCandidates;
+ for (int level = curMaxLevel; level > nodeLevel; level--) {
+ candidates.clear();
+ graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
+ eps[0] = candidates.popNode();
+ }
+
+ // for levels <= nodeLevel search with topk = beamWidth, and add connections
+ candidates = beamCandidates;
+ NeighborArray[] scratchPerLevel =
+ new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1];
+ for (int i = scratchPerLevel.length - 1; i >= 0; i--) {
+ int level = i + lowestUnsetLevel;
+ candidates.clear();
+ graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null);
+ eps = candidates.popUntilNearestKNodes();
+ scratchPerLevel[i] = new NeighborArray(Math.max(beamCandidates.k(), M + 1), false);
+ popToScratch(candidates, scratchPerLevel[i]);
+ }
+
+ // then do connections from bottom up
+ for (int i = 0; i < scratchPerLevel.length; i++) {
+ addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
+ }
+ lowestUnsetLevel += scratchPerLevel.length;
+ assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
+ if (lowestUnsetLevel > nodeLevel) {
+ return;
+ }
+ assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
+ if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
+ return;
+ }
+ if (hnsw.numLevels() == curMaxLevel + 1) {
+ // This should never happen if all the calculations are correct
+ throw new IllegalStateException(
+ "We're not able to promote node "
+ + node
+ + " at level "
+ + nodeLevel
+ + " as entry node. But the max graph level "
+ + curMaxLevel
+              + " has not changed while we were inserting the node.");
+ }
+ } while (true);
}
private long printGraphBuildStatus(int node, long start, long t) {
@@ -221,7 +290,7 @@
return now;
}
- private void addDiverseNeighbors(int level, int node, GraphBuilderKnnCollector candidates)
+ private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
@@ -229,26 +298,40 @@
*/
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
- popToScratch(candidates);
int maxConnOnLevel = level == 0 ? M * 2 : M;
- selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);
+ boolean[] mask = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
- int size = neighbors.size();
- for (int i = 0; i < size; i++) {
- int nbr = neighbors.node[i];
+    // NOTE: here we use the candidates and mask rather than the neighbour array, because once we
+    // have added an incoming link this node can be discovered and its neighbour array modified
+    // concurrently. Using the local candidates and mask is the safer option.
+ for (int i = 0; i < candidates.size(); i++) {
+ if (mask[i] == false) {
+ continue;
+ }
+ int nbr = candidates.node[i];
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
- nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
- if (nbrsOfNbr.size() > maxConnOnLevel) {
- int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
- nbrsOfNbr.removeIndex(indexToRemove);
+ nbrsOfNbr.rwlock.writeLock().lock();
+ try {
+ nbrsOfNbr.addOutOfOrder(node, candidates.score[i]);
+ if (nbrsOfNbr.size() > maxConnOnLevel) {
+ int indexToRemove = findWorstNonDiverse(nbrsOfNbr, nbr);
+ nbrsOfNbr.removeIndex(indexToRemove);
+ }
+ } finally {
+ nbrsOfNbr.rwlock.writeLock().unlock();
}
}
}
- private void selectAndLinkDiverse(
+ /**
+   * This method selects the neighbors to add and returns a mask telling the caller which
+   * candidates were selected
+ */
+ private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
+ boolean[] mask = new boolean[candidates.size()];
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
@@ -257,12 +340,16 @@
float cScore = candidates.score[i];
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
+ mask[i] = true;
+        // here we don't need to lock: there is no incoming link yet, so no other thread can
+        // discover this node, and hence no other thread will modify this neighbor array either
neighbors.addInOrder(cNode, cScore);
}
}
+ return mask;
}
- private void popToScratch(GraphBuilderKnnCollector candidates) {
+ private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
scratch.clear();
int candidateCount = candidates.size();
// extract all the Neighbors from the queue into an array; these will now be
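The trySetNewEntryNode/tryPromoteNewEntryNode calls used by addGraphNode above are CAS operations
on the AtomicReference<EntryNode> introduced in OnHeapHnswGraph (see its hunk below). Their bodies
fall outside the visible hunks; the following is a plausible reconstruction for illustration only,
not the verbatim implementation:

    static final class EntryNode {
      final int node, level;
      EntryNode(int node, int level) { this.node = node; this.level = level; }
    }

    // starts as new AtomicReference<>(new EntryNode(-1, 1)): no entry node, one implicit level
    boolean trySetNewEntryNode(int node, int level) {
      EntryNode current = entryNode.get();
      if (current.node != -1) {
        return false; // another thread already set an entry node
      }
      // CAS against the exact instance we read; AtomicReference compares references
      return entryNode.compareAndSet(current, new EntryNode(node, level));
    }

    boolean tryPromoteNewEntryNode(int node, int level, int expectOldLevel) {
      EntryNode current = entryNode.get();
      if (current.level != expectOldLevel) {
        return false; // the graph grew in the meantime; the caller retries on the new levels
      }
      return entryNode.compareAndSet(current, new EntryNode(node, level));
    }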
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java
new file mode 100644
index 0000000..7ed5dd1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.hnsw;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.InfoStream;
+
+/**
+ * Abstraction of merging multiple graphs into one on-heap graph
+ *
+ * @lucene.experimental
+ */
+public interface HnswGraphMerger {
+
+ /**
+ * Adds a reader to the graph merger to record the state
+ *
+ * @param reader KnnVectorsReader to add to the merger
+ * @param docMap MergeState.DocMap for the reader
+ * @param liveDocs Bits representing live docs, can be null
+ * @return this
+ * @throws IOException If an error occurs while reading from the merge state
+ */
+ HnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs)
+ throws IOException;
+
+ /**
+ * Merge and produce the on heap graph
+ *
+ * @param mergedVectorIterator iterator over the vectors in the merged segment
+   * @param infoStream optional info stream to set on the builder
+ * @param maxOrd max number of vectors that will be added to the graph
+ * @return merged graph
+ * @throws IOException during merge
+ */
+ OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
+ throws IOException;
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
index 5ddb239..ddcbfda 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
@@ -32,6 +32,7 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.InfoStream;
/**
 * This selects the biggest Hnsw graph from the provided merge state and initializes a new
 * HnswGraphBuilder with that graph as a starting point.
 *
*
* @lucene.experimental
*/
-public class IncrementalHnswGraphMerger {
+public class IncrementalHnswGraphMerger implements HnswGraphMerger {
- private KnnVectorsReader initReader;
- private MergeState.DocMap initDocMap;
- private int initGraphSize;
- private final FieldInfo fieldInfo;
- private final RandomVectorScorerSupplier scorerSupplier;
- private final int M;
- private final int beamWidth;
+ protected final FieldInfo fieldInfo;
+ protected final RandomVectorScorerSupplier scorerSupplier;
+ protected final int M;
+ protected final int beamWidth;
+
+ protected KnnVectorsReader initReader;
+ protected MergeState.DocMap initDocMap;
+ protected int initGraphSize;
/**
* @param fieldInfo FieldInfo for the field being merged
@@ -64,13 +66,8 @@
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
* deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
* previous reader that met the above criteria
- *
- * @param reader KnnVectorsReader to add to the merger
- * @param docMap MergeState.DocMap for the reader
- * @param liveDocs Bits representing live docs, can be null
- * @return this
- * @throws IOException If an error occurs while reading from the merge state
*/
+ @Override
public IncrementalHnswGraphMerger addReader(
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
KnnVectorsReader currKnnVectorsReader = reader;
@@ -113,18 +110,20 @@
* If no valid readers were added to the merge state, a new graph is created.
*
* @param mergedVectorIterator iterator over the vectors in the merged segment
+   * @param maxOrd max number of vectors that will be merged into the graph
* @return HnswGraphBuilder
* @throws IOException If an error occurs while reading from the merge state
*/
- public HnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator) throws IOException {
+ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
+ throws IOException {
if (initReader == null) {
- return HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
+ return HnswGraphBuilder.create(
+ scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed, maxOrd);
}
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
- final int numVectors = Math.toIntExact(mergedVectorIterator.cost());
- BitSet initializedNodes = new FixedBitSet(numVectors + 1);
+ BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,
@@ -134,7 +133,15 @@
initializerGraph,
oldToNewOrdinalMap,
initializedNodes,
- numVectors);
+ maxOrd);
+ }
+
+ @Override
+ public OnHeapHnswGraph merge(
+ DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException {
+ HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd);
+ builder.setInfoStream(infoStream);
+ return builder.build(maxOrd);
}
/**
@@ -146,8 +153,8 @@
* @return the mapping from old ordinals to new ordinals
* @throws IOException If an error occurs while reading from the merge state
*/
- private int[] getNewOrdMapping(DocIdSetIterator mergedVectorIterator, BitSet initializedNodes)
- throws IOException {
+ protected final int[] getNewOrdMapping(
+ DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException {
DocIdSetIterator initializerIterator = null;
switch (fieldInfo.getVectorEncoding()) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
index 025fdff..179d243 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
@@ -55,6 +55,18 @@
BitSet initializedNodes,
int totalNumberOfVectors)
throws IOException {
+ return new InitializedHnswGraphBuilder(
+ scorerSupplier,
+ M,
+ beamWidth,
+ seed,
+ initGraph(M, initializerGraph, newOrdMap, totalNumberOfVectors),
+ initializedNodes);
+ }
+
+ public static OnHeapHnswGraph initGraph(
+ int M, HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors)
+ throws IOException {
OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
@@ -62,6 +74,7 @@
int oldOrd = it.nextInt();
int newOrd = newOrdMap[oldOrd];
hnsw.addNode(level, newOrd);
+ hnsw.trySetNewEntryNode(newOrd, level);
NeighborArray newNeighbors = hnsw.getNeighbors(level, newOrd);
initializerGraph.seek(level, oldOrd);
for (int oldNeighbor = initializerGraph.nextNeighbor();
@@ -73,8 +86,7 @@
}
}
}
- return new InitializedHnswGraphBuilder(
- scorerSupplier, M, beamWidth, seed, hnsw, initializedNodes);
+ return hnsw;
}
private final BitSet initializedNodes;
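
Extracting initGraph into a public static helper lets tests (and other builders) materialize the seeded graph directly; a hedged usage sketch with arguments named as in the signature above:

    // Sketch: copy initializerGraph into a new OnHeapHnswGraph sized for the merged
    // segment, renumbering every node through newOrdMap and promoting the entry
    // node level by level as nodes are re-added from the top level down.
    OnHeapHnswGraph seeded =
        InitializedHnswGraphBuilder.initGraph(M, initializerGraph, newOrdMap, totalNumberOfVectors);
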
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index f6cd54f..086a5ad 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -19,6 +19,8 @@
import java.io.IOException;
import java.util.Arrays;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.lucene.util.ArrayUtil;
/**
@@ -35,6 +37,7 @@
float[] score;
int[] node;
private int sortedNodeSize;
+ public final ReadWriteLock rwlock = new ReentrantReadWriteLock(true);
public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
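
The new fair ReadWriteLock on each NeighborArray is what lets multiple merge workers touch the same graph; a self-contained sketch of the intended discipline (the actual call sites are in the concurrent builder, not in this hunk):

    import java.util.concurrent.locks.ReadWriteLock;
    import java.util.concurrent.locks.ReentrantReadWriteLock;

    class LockedNeighbors {
      final ReadWriteLock rwlock = new ReentrantReadWriteLock(true); // fair, as above
      private int size;
      private final int[] node = new int[16];

      // Many searchers may read a node's neighbor list concurrently...
      int[] snapshot() {
        rwlock.readLock().lock();
        try {
          return java.util.Arrays.copyOf(node, size);
        } finally {
          rwlock.readLock().unlock();
        }
      }

      // ...while a single writer at a time appends a new neighbor.
      void add(int neighbor) {
        rwlock.writeLock().lock();
        try {
          node[size++] = neighbor;
        } finally {
          rwlock.writeLock().unlock();
        }
      }
    }
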
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index bcb78e8..cfbfe5b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -21,6 +21,8 @@
import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
@@ -33,8 +35,7 @@
private static final int INIT_SIZE = 128;
- private int numLevels; // the current number of levels in the graph
- private int entryNode; // the current graph entry node on the top level. -1 if not set
+ private final AtomicReference<EntryNode> entryNode;
// the internal graph representation where the first dimension is node id and second dimension is
// level
@@ -47,11 +48,13 @@
private int
lastFreezeSize; // remember the size we are at last time to freeze the graph and generate
// levelToNodes
- private int size; // graph size, which is number of nodes in level 0
- private int
- nonZeroLevelSize; // total number of NeighborArrays created that is not on level 0, for now it
+ private final AtomicInteger size =
+ new AtomicInteger(0); // graph size, which is number of nodes in level 0
+ private final AtomicInteger nonZeroLevelSize =
+ new AtomicInteger(
+ 0); // total number of NeighborArrays created that are not on level 0; for now it
// is only used to account memory usage
- private int maxNodeId;
+ private final AtomicInteger maxNodeId = new AtomicInteger(-1);
private final int nsize; // neighbour array size at non-zero level
private final int nsize0; // neighbour array size at zero level
private final boolean
@@ -69,11 +72,9 @@
* growing itself (you cannot add a node whose id is >= numNodes)
*/
OnHeapHnswGraph(int M, int numNodes) {
- this.numLevels = 1; // Implicitly start the graph with a single level
- this.entryNode = -1; // Entry node should be negative until a node is added
+ this.entryNode = new AtomicReference<>(new EntryNode(-1, 1));
// Neighbours' size on upper levels (nsize) and level 0 (nsize0)
// We allocate extra space for neighbours, but then prune them to keep allowed maximum
- this.maxNodeId = -1;
this.nsize = M + 1;
this.nsize0 = (M * 2 + 1);
noGrowth = numNodes != -1;
@@ -96,7 +97,7 @@
@Override
public int size() {
- return size;
+ return size.get();
}
/**
@@ -107,7 +108,16 @@
*/
@Override
public int maxNodeId() {
- return maxNodeId;
+ if (noGrowth) {
+ // we know the eventual graph size and the graph may be
+ // concurrently modified
+ return graph.length - 1;
+ } else {
+ // The graph cannot be concurrently modified (and searched) if
+ // we don't know the size beforehand, so it's safe to return the
+ // actual maxNodeId
+ return maxNodeId.get();
+ }
}
/**
@@ -120,9 +130,6 @@
* @param node the node to add, represented as an ordinal on the level 0.
*/
public void addNode(int level, int node) {
- if (entryNode == -1) {
- entryNode = node;
- }
if (node >= graph.length) {
if (noGrowth) {
@@ -132,25 +139,20 @@
graph = ArrayUtil.grow(graph, node + 1);
}
- if (level >= numLevels) {
- numLevels = level + 1;
- entryNode = node;
- }
-
assert graph[node] == null || graph[node].length > level
: "node must be inserted from the top level";
if (graph[node] == null) {
graph[node] =
new NeighborArray[level + 1]; // assumption: we always call this function from top level
- size++;
+ size.incrementAndGet();
}
if (level == 0) {
graph[node][level] = new NeighborArray(nsize0, true);
} else {
graph[node][level] = new NeighborArray(nsize, true);
- nonZeroLevelSize++;
+ nonZeroLevelSize.incrementAndGet();
}
- maxNodeId = Math.max(maxNodeId, node);
+ maxNodeId.accumulateAndGet(node, Math::max);
}
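
Replacing the plain-int Math.max update with AtomicInteger.accumulateAndGet makes the running maximum lock-free; a minimal self-contained sketch of why no concurrent update is lost:

    import java.util.concurrent.atomic.AtomicInteger;

    class MaxNodeId {
      private final AtomicInteger maxNodeId = new AtomicInteger(-1);

      // accumulateAndGet retries its CAS until the stored value already reflects
      // Math.max(current, node), so racing addNode calls cannot drop an update.
      void onNodeAdded(int node) {
        maxNodeId.accumulateAndGet(node, Math::max);
      }

      int get() {
        return maxNodeId.get();
      }
    }
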
@Override
@@ -174,7 +176,7 @@
*/
@Override
public int numLevels() {
- return numLevels;
+ return entryNode.get().level + 1;
}
/**
@@ -185,7 +187,41 @@
*/
@Override
public int entryNode() {
- return entryNode;
+ return entryNode.get().node;
+ }
+
+ /**
+ * Try to set the entry node if the graph does not have one
+ *
+ * @return true if the entry node is set to the provided node; false if the entry node already
+ * exists.
+ */
+ public boolean trySetNewEntryNode(int node, int level) {
+ EntryNode current = entryNode.get();
+ if (current.node == -1) {
+ return entryNode.compareAndSet(current, new EntryNode(node, level));
+ }
+ return false;
+ }
+
+ /**
+ * Try to promote the provided node to the entry node
+ *
+ * @param level must be larger than expectOldLevel
+ * @param expectOldLevel the entry node level the caller expects the graph to have; the actual
+ * level can be different due to concurrent modification
+ * @return true if the entry node is set to the provided node; false if expectOldLevel does not
+ * match the current entry node level. Even if the provided node's level is still higher
+ * than the current entry node level, the new entry node will not be set and false will be
+ * returned.
+ */
+ public boolean tryPromoteNewEntryNode(int node, int level, int expectOldLevel) {
+ assert level > expectOldLevel;
+ EntryNode currentEntry = entryNode.get();
+ if (currentEntry.level == expectOldLevel) {
+ return entryNode.compareAndSet(currentEntry, new EntryNode(node, level));
+ }
+ return false;
}
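
Both helpers are single-shot CAS attempts, so callers are expected to loop; a hedged sketch of the retry pattern (the real call site lives in the builder, outside this hunk):

    // Sketch: after adding `node` at `level`, make sure the entry node reflects
    // the highest level seen. trySetNewEntryNode only wins on an empty graph;
    // tryPromoteNewEntryNode must be retried while other threads race the entry
    // level upward.
    if (graph.trySetNewEntryNode(node, level) == false) {
      int curLevel;
      while (level > (curLevel = graph.numLevels() - 1)) {
        if (graph.tryPromoteNewEntryNode(node, level, curLevel)) {
          break;
        }
      }
    }
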
/**
@@ -212,12 +248,12 @@
@SuppressWarnings({"unchecked", "rawtypes"})
private void generateLevelToNodes() {
- if (lastFreezeSize == size) {
+ if (lastFreezeSize == size()) {
return;
}
-
- levelToNodes = new List[numLevels];
- for (int i = 1; i < numLevels; i++) {
+ int maxLevels = numLevels();
+ levelToNodes = new List[maxLevels];
+ for (int i = 1; i < maxLevels; i++) {
levelToNodes[i] = new ArrayList<>();
}
int nonNullNode = 0;
@@ -230,38 +266,44 @@
for (int i = 1; i < graph[node].length; i++) {
levelToNodes[i].add(node);
}
- if (nonNullNode == size) {
+ if (nonNullNode == size()) {
break;
}
}
- lastFreezeSize = size;
+ lastFreezeSize = size();
}
@Override
public long ramBytesUsed() {
long neighborArrayBytes0 =
(long) nsize0 * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long neighborArrayBytes =
(long) nsize * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2L
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2L
+ Integer.BYTES * 3;
long total = 0;
total +=
- size * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ size() * (neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // for graph and level 0;
- total += nonZeroLevelSize * neighborArrayBytes; // for non-zero level
- total += 8 * Integer.BYTES; // all int fields
+ total += nonZeroLevelSize.get() * neighborArrayBytes; // for non-zero level
+ total += 4 * Integer.BYTES; // all int fields
+ total += 1; // field: noGrowth
+ total +=
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ + 2 * Integer.BYTES; // field: entryNode
+ total += 3L * (Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER); // 3 AtomicInteger
total += RamUsageEstimator.NUM_BYTES_OBJECT_REF; // field: cur
total += RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; // field: levelToNodes
if (levelToNodes != null) {
total +=
- (long) (numLevels - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
+ (long) (numLevels() - 1) * RamUsageEstimator.NUM_BYTES_OBJECT_REF; // no cost for level 0
total +=
- (long) nonZeroLevelSize
+ (long) nonZeroLevelSize.get()
* (RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ Integer.BYTES);
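
For concreteness, a worked instance of the updated per-NeighborArray accounting, assuming the common 64-bit JVM constants (16-byte array header, 4-byte object reference) and M = 16, so nsize0 = 2 * 16 + 1 = 33:

    neighborArrayBytes0 = nsize0 * (Integer.BYTES + Float.BYTES)  // node[] + score[] payloads
                        + 2 * NUM_BYTES_ARRAY_HEADER              // the two array headers
                        + 2 * NUM_BYTES_OBJECT_REF                // the two array references
                        + 3 * Integer.BYTES                       // small int fields
                        = 33 * 8 + 2 * 16 + 2 * 4 + 3 * 4
                        = 264 + 32 + 8 + 12 = 316 bytes
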
@@ -274,9 +316,11 @@
return "OnHeapHnswGraph(size="
+ size()
+ ", numLevels="
- + numLevels
+ + numLevels()
+ ", entryNode="
- + entryNode
+ + entryNode()
+ ")";
}
+
+ private record EntryNode(int node, int level) {}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java
index a922e2f..1db50ee 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java
@@ -32,12 +32,14 @@
RandomVectorScorer scorer(int ord) throws IOException;
/**
- * Creates a {@link RandomVectorScorerSupplier} to compare float vectors.
- *
- * <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
- * using it after calling this function. If you plan to use it again outside the returned {@link
- * RandomVectorScorer}, think about passing a copied version ({@link
- * RandomAccessVectorValues#copy}).
+ * Make a copy of the supplier, which will copy the underlying vectorValues so the copy is safe
+ * to use in other threads.
+ */
+ RandomVectorScorerSupplier copy() throws IOException;
+
+ /**
+ * Creates a {@link RandomVectorScorerSupplier} to compare float vectors. The vectorValues passed
+ * in will be copied; the original instance will not be used by the returned supplier.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
@@ -48,21 +50,12 @@
throws IOException {
// We copy the provided random accessor just once during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
- final RandomAccessVectorValues<float[]> vectorsCopy = vectors.copy();
- return queryOrd ->
- (RandomVectorScorer)
- cand ->
- similarityFunction.compare(
- vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
+ return new FloatScoringSupplier(vectors, similarityFunction);
}
/**
- * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors.
- *
- * <p>WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid
- * using it after calling this function. If you plan to use it again outside the returned {@link
- * RandomVectorScorer}, think about passing a copied version ({@link
- * RandomAccessVectorValues#copy}).
+ * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. The vectorValues passed
+ * in will be copied; the original instance will not be used by the returned supplier.
*
* @param vectors the underlying storage for vectors
* @param similarityFunction the similarity function to score vectors
@@ -71,13 +64,64 @@
final RandomAccessVectorValues<byte[]> vectors,
final VectorSimilarityFunction similarityFunction)
throws IOException {
- // We copy the provided random accessor just once during the supplier's initialization
+ // We copy the provided random accessor only during the supplier's initialization
// and then reuse it consistently across all scorers for conducting vector comparisons.
- final RandomAccessVectorValues<byte[]> vectorsCopy = vectors.copy();
- return queryOrd ->
- (RandomVectorScorer)
- cand ->
- similarityFunction.compare(
- vectors.vectorValue(queryOrd), vectorsCopy.vectorValue(cand));
+ return new ByteScoringSupplier(vectors, similarityFunction);
+ }
+
+ /** RandomVectorScorerSupplier for byte vectors */
+ final class ByteScoringSupplier implements RandomVectorScorerSupplier {
+ private final RandomAccessVectorValues<byte[]> vectors;
+ private final RandomAccessVectorValues<byte[]> vectors1;
+ private final RandomAccessVectorValues<byte[]> vectors2;
+ private final VectorSimilarityFunction similarityFunction;
+
+ private ByteScoringSupplier(
+ RandomAccessVectorValues<byte[]> vectors, VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ this.vectors = vectors;
+ vectors1 = vectors.copy();
+ vectors2 = vectors.copy();
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ return cand ->
+ similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new ByteScoringSupplier(vectors, similarityFunction);
+ }
+ }
+
+ /** RandomVectorScorerSupplier for float vectors */
+ final class FloatScoringSupplier implements RandomVectorScorerSupplier {
+ private final RandomAccessVectorValues<float[]> vectors;
+ private final RandomAccessVectorValues<float[]> vectors1;
+ private final RandomAccessVectorValues<float[]> vectors2;
+ private final VectorSimilarityFunction similarityFunction;
+
+ private FloatScoringSupplier(
+ RandomAccessVectorValues<float[]> vectors, VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ this.vectors = vectors;
+ vectors1 = vectors.copy();
+ vectors2 = vectors.copy();
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ return cand ->
+ similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new FloatScoringSupplier(vectors, similarityFunction);
+ }
}
}
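
The point of the new copy() contract is thread confinement: each merge worker clones the supplier so the stateful read buffers inside the vector values are never shared. A hedged sketch (the factory method name and the scorer's score(int) method are assumed from context, since this hunk elides them):

    // Per-thread scoring, sketched: the shared supplier is built once, then each
    // worker calls copy() to get its own pair of cloned vector readers.
    RandomVectorScorerSupplier shared =
        RandomVectorScorerSupplier.createFloats(vectors, similarityFunction); // name assumed
    RandomVectorScorerSupplier forWorker = shared.copy();
    RandomVectorScorer scorer = forWorker.scorer(queryOrd);
    float score = scorer.score(candidateOrd); // score(int) assumed from the lambda shape above
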
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java
index aa9c9f1..8f2fdd2 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java
@@ -33,7 +33,7 @@
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
- return new Lucene99HnswVectorsFormat(10, 20, null);
+ return new Lucene99HnswVectorsFormat(10, 20);
}
};
String expectedString =
@@ -42,13 +42,11 @@
}
public void testLimits() {
- expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20, null));
- expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20, null));
- expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0, null));
- expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1, null));
- expectThrows(
- IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20, null));
- expectThrows(
- IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201, null));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20));
+ expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201));
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
index c578ae8..c51f51f 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
@@ -170,8 +170,8 @@
try (Directory directory = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
KnnVectorsFormat format1 =
- new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100, null));
- KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100, null);
+ new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100));
+ KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100);
iwc.setCodec(
new AssertingCodec() {
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index 368f6f9..32f3259 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -113,8 +113,7 @@
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene99HnswVectorsFormat(
- M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, null);
+ return new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSubScorerFreqs.java b/lucene/core/src/test/org/apache/lucene/search/TestSubScorerFreqs.java
index d3db022..cc2cf69 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestSubScorerFreqs.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSubScorerFreqs.java
@@ -36,6 +36,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.AssertingScorable;
+import org.apache.lucene.tests.search.DisablingBulkScorerQuery;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -203,7 +204,8 @@
for (final Set<String> occur : occurList) {
Map<Integer, Map<Query, Float>> docCounts =
- s.search(query.build(), new CountingCollectorManager(occur));
+ s.search(
+ new DisablingBulkScorerQuery(query.build()), new CountingCollectorManager(occur));
final int maxDocs = s.getIndexReader().maxDoc();
assertEquals(maxDocs, docCounts.size());
boolean includeOptional = occur.contains("SHOULD");
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index 358db95..3153f39 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -17,6 +17,7 @@
package org.apache.lucene.util;
import java.util.Random;
+import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -115,6 +116,19 @@
expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
}
+ public void testExtremeNumerics() {
+ float[] v1 = new float[1536];
+ float[] v2 = new float[1536];
+ for (int i = 0; i < 1536; i++) {
+ v1[i] = 0.888888f;
+ v2[i] = -0.777777f;
+ }
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ float v = vectorSimilarityFunction.compare(v1, v2);
+ assertTrue(vectorSimilarityFunction + " expected >=0 got:" + v, v >= 0);
+ }
+ }
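
The constants here (0.888888f, -0.777777f over 1536 dimensions) are chosen so accumulated float error pushes a similarity right up against its lower bound; a hedged illustration of the failure mode this guards against (the scaling formula is illustrative, not a quote of Lucene's code):

    // Scaled similarities of the shape (1 + cos) / 2 live in [0, 1] in exact
    // arithmetic, but a cosine that underflows slightly past -1 turns the
    // scaled score negative, which downstream score consumers reject.
    float cos = -1.0000001f;        // accumulated rounding error past the true bound
    float scaled = (1f + cos) / 2f; // about -6e-8f: a "negative similarity"
    float fixed = Math.max(0f, scaled);
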
+
private static float l2(float[] v) {
float l2 = 0;
for (float x : v) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestAutomaton.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestAutomaton.java
index 6abe0aa..e4dd739 100644
--- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestAutomaton.java
+++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestAutomaton.java
@@ -790,9 +790,18 @@
return null;
}
+ private static boolean hasMassiveTerm(Collection<BytesRef> terms) {
+ for (BytesRef term : terms) {
+ if (term.length > Automata.MAX_STRING_UNION_TERM_LENGTH) {
+ return true;
+ }
+ }
+ return false;
+ }
+
private Automaton unionTerms(Collection<BytesRef> terms) {
Automaton a;
- if (random().nextBoolean()) {
+ if (random().nextBoolean() || hasMassiveTerm(terms)) {
if (VERBOSE) {
System.out.println("TEST: unionTerms: use union");
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestStringsToAutomaton.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestStringsToAutomaton.java
index f093623..0e5a3f9 100644
--- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestStringsToAutomaton.java
+++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestStringsToAutomaton.java
@@ -103,7 +103,7 @@
.startsWith(
"This builder doesn't allow terms that are larger than "
+ Automata.MAX_STRING_UNION_TERM_LENGTH
- + " characters"));
+ + " UTF-8 bytes"));
byte[] b1k = ArrayUtil.copyOfSubArray(b10k, 0, 1000);
build(Collections.singleton(new BytesRef(b1k)), false); // no exception
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index a6a4259..22670e1 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -165,7 +165,7 @@
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene99HnswVectorsFormat(M, beamWidth, null);
+ return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@@ -237,7 +237,7 @@
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene99HnswVectorsFormat(M, beamWidth, null);
+ return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@@ -298,7 +298,7 @@
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene99HnswVectorsFormat(M, beamWidth, null);
+ return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@@ -312,7 +312,7 @@
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene99HnswVectorsFormat(M, beamWidth, null);
+ return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@@ -565,6 +565,14 @@
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
RandomVectorScorerSupplier finalscorerSupplier = buildScorerSupplier(finalVectorValues);
+
+ // we cannot call getNodesOnLevel before the graph reaches its claimed size, so we create
+ // another graph to do the assertion
+ OnHeapHnswGraph graphAfterInit =
+ InitializedHnswGraphBuilder.initGraph(
+ 10, initializerGraph, initializerOrdMap, initializerGraph.size());
+
HnswGraphBuilder finalBuilder =
InitializedHnswGraphBuilder.fromGraph(
finalscorerSupplier,
@@ -578,7 +586,7 @@
totalSize);
// When offset is 0, the graphs should be identical before vectors are added
- assertGraphEqual(initializerGraph, finalBuilder.getGraph());
+ assertGraphEqual(initializerGraph, graphAfterInit);
OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.size());
assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
@@ -989,6 +997,33 @@
}
}
+ /*
+ * A very basic test that ensures the concurrent merge does not throw exceptions. It by no means
+ * guarantees the correctness of the concurrent merge; that must be checked manually by running a
+ * KNN benchmark and comparing the recall.
+ */
+ public void testConcurrentMergeBuilder() throws IOException {
+ int size = atLeast(1000);
+ int dim = atLeast(10);
+ AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
+ RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
+ ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
+ HnswGraphBuilder.randSeed = random().nextLong();
+ HnswConcurrentMergeBuilder builder =
+ new HnswConcurrentMergeBuilder(
+ exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
+ builder.setBatchSize(100);
+ builder.build(size);
+ exec.shutdownNow();
+ OnHeapHnswGraph graph = builder.getGraph();
+ assertTrue(graph.entryNode() != -1);
+ assertEquals(size, graph.size());
+ assertEquals(size - 1, graph.maxNodeId());
+ for (int l = 0; l < graph.numLevels(); l++) {
+ assertNotNull(graph.getNodesOnLevel(l));
+ }
+ }
+
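
Outside of tests, concurrent merging is enabled through the Lucene99HnswVectorsFormat constructor; a heavily hedged sketch, since this diff only shows the two-argument calls and the extended parameter list is assumed:

    // Hypothetical wiring (constructor shape assumed, not shown in this hunk):
    // pass a worker count plus executor so segment merges build the graph with
    // multiple threads, mirroring the HnswConcurrentMergeBuilder test above.
    ExecutorService mergeExec =
        Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
    KnnVectorsFormat format = new Lucene99HnswVectorsFormat(16, 100, 4, mergeExec);
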
private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
index 23c3511..f5d2a8a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestOnHeapHnswGraph.java
@@ -43,7 +43,7 @@
/* assert an exception is thrown when we call getNodesOnLevel on an incomplete graph */
public void testIncompleteGraphThrow() {
- OnHeapHnswGraph graph = new OnHeapHnswGraph(10, 10);
+ OnHeapHnswGraph graph = new OnHeapHnswGraph(10, -1);
graph.addNode(1, 0);
graph.addNode(0, 0);
assertEquals(1, graph.getNodesOnLevel(1).size());
@@ -62,6 +62,10 @@
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
+ graph.trySetNewEntryNode(i, l);
+ if (l > graph.numLevels() - 1) {
+ graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
+ }
levelToNodes.get(l).add(i);
}
}
@@ -93,6 +97,10 @@
int level = random().nextInt(maxLevel);
for (int l = level; l >= 0; l--) {
graph.addNode(l, i);
+ graph.trySetNewEntryNode(i, l);
+ if (l > graph.numLevels() - 1) {
+ graph.tryPromoteNewEntryNode(i, l, graph.numLevels() - 1);
+ }
levelToNodes.get(l).add(i);
}
}
diff --git a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java
index 102360f..3aa20b5 100644
--- a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java
+++ b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java
@@ -157,19 +157,9 @@
}
public void testSpanNearQueryWithoutSlopXML() throws Exception {
- Exception expectedException = new NumberFormatException("For input string: \"\"");
- try {
- Query q = parse("SpanNearQueryWithoutSlop.xml");
- fail("got query " + q + " instead of expected exception " + expectedException);
- } catch (Exception e) {
- assertEquals(expectedException.toString(), e.toString());
- }
- try {
- SpanQuery sq = parseAsSpan("SpanNearQueryWithoutSlop.xml");
- fail("got span query " + sq + " instead of expected exception " + expectedException);
- } catch (Exception e) {
- assertEquals(expectedException.toString(), e.toString());
- }
+ // expected NumberFormatException from empty "slop" string
+ assertThrows(NumberFormatException.class, () -> parse("SpanNearQueryWithoutSlop.xml"));
+ assertThrows(NumberFormatException.class, () -> parseAsSpan("SpanNearQueryWithoutSlop.xml"));
}
public void testConstantScoreQueryXML() throws Exception {
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/vector/ConfigurableMCodec.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/vector/ConfigurableMCodec.java
index 4cd473c..11aa4ac 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/vector/ConfigurableMCodec.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/vector/ConfigurableMCodec.java
@@ -32,12 +32,12 @@
public ConfigurableMCodec() {
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
- knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100, null);
+ knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100);
}
public ConfigurableMCodec(int maxConn) {
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
- knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100, null);
+ knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100);
}
@Override
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/DisablingBulkScorerQuery.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DisablingBulkScorerQuery.java
new file mode 100644
index 0000000..e6cee51
--- /dev/null
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DisablingBulkScorerQuery.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.tests.search;
+
+import java.io.IOException;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.FilterWeight;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+
+/** A {@link Query} wrapper that disables bulk-scoring optimizations. */
+public class DisablingBulkScorerQuery extends Query {
+
+ private final Query query;
+
+ /** Sole constructor. */
+ public DisablingBulkScorerQuery(Query query) {
+ this.query = query;
+ }
+
+ @Override
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ Query rewritten = query.rewrite(indexSearcher);
+ if (query != rewritten) {
+ return new DisablingBulkScorerQuery(rewritten);
+ }
+ return super.rewrite(indexSearcher);
+ }
+
+ @Override
+ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+ throws IOException {
+ Weight in = query.createWeight(searcher, scoreMode, boost);
+ return new FilterWeight(in) {
+ @Override
+ public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
+ Scorer scorer = scorer(context);
+ if (scorer == null) {
+ return null;
+ }
+ return new DefaultBulkScorer(scorer);
+ }
+ };
+ }
+
+ @Override
+ public String toString(String field) {
+ return query.toString(field);
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ query.visit(visitor);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return sameClassAs(obj) && query.equals(((DisablingBulkScorerQuery) obj).query);
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * classHash() + query.hashCode();
+ }
+}
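
A short usage note: wrapping a query in this class forces the per-document Scorer path, which lets tests that inspect scorer-level calls (like TestSubScorerFreqs above) keep working when a query would otherwise be bulk-scored. Sketch (searcher and originalQuery are illustrative):

    // Scoring results are unchanged; only the bulk-scoring fast path is bypassed
    // in favor of a DefaultBulkScorer over the plain Scorer.
    Query wrapped = new DisablingBulkScorerQuery(originalQuery);
    TopDocs td = searcher.search(wrapped, 10);
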
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/RamUsageTester.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/RamUsageTester.java
index 7a234ca..91d8f0c 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/RamUsageTester.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/RamUsageTester.java
@@ -39,6 +39,8 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
@@ -159,7 +161,11 @@
// Ignore JDK objects we can't access or handle properly.
Predicate<Object> isIgnorable =
- (clazz) -> (clazz instanceof CharsetEncoder) || (clazz instanceof CharsetDecoder);
+ (clazz) ->
+ (clazz instanceof CharsetEncoder)
+ || (clazz instanceof CharsetDecoder)
+ || (clazz instanceof ReentrantReadWriteLock)
+ || (clazz instanceof AtomicReference<?>);
if (isIgnorable.test(ob)) {
return accumulator.accumulateObject(ob, 0, Collections.emptyMap(), stack);
}