Ensure negative scores aren not returned from scalar quantization scorer (#13356)
Depending on how we quantize and then scale, we can edge down below 0 for dotproduct scores.
This is exceptionally rare, I have only seen it in extreme circumstances in tests (with random data and low dimensionality).
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 5d3adf5..b44c562 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -361,6 +361,8 @@
* GITHUB#12966: Aggregation facets no longer assume that aggregation values are positive. (Stefan Vodita)
+* GITHUB#13356: Ensure negative scores are not returned from scalar quantization scorer. (Ben Trent)
+
Build
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java
index edb3559..96c9358 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java
@@ -100,11 +100,10 @@
return switch (sim) {
case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes);
case COSINE, DOT_PRODUCT -> dotProductFactory(
- targetBytes, offsetCorrection, sim, constMultiplier, values, f -> (1 + f) / 2);
+ targetBytes, offsetCorrection, constMultiplier, values, f -> Math.max((1 + f) / 2, 0));
case MAXIMUM_INNER_PRODUCT -> dotProductFactory(
targetBytes,
offsetCorrection,
- sim,
constMultiplier,
values,
VectorUtil::scaleMaxInnerProductScore);
@@ -114,7 +113,6 @@
private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
byte[] targetBytes,
float offsetCorrection,
- VectorSimilarityFunction sim,
float constMultiplier,
RandomAccessQuantizedByteVectorValues values,
FloatToFloatFunction scoreAdjustmentFunction) {
@@ -179,6 +177,8 @@
byte[] storedVector = values.vectorValue(vectorOrdinal);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes);
+ // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+ assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
@@ -216,6 +216,8 @@
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
+ // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+ assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
@@ -247,6 +249,8 @@
byte[] storedVector = values.vectorValue(vectorOrdinal);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes);
+ // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+ assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java
index 77ad410..6c11ef7 100644
--- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java
@@ -80,8 +80,10 @@
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = comparator.compare(storedVector, queryVector);
+ // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+ assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
- return (1 + adjustedDistance) / 2;
+ return Math.max((1 + adjustedDistance) / 2, 0);
}
}
@@ -99,6 +101,8 @@
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = comparator.compare(storedVector, queryVector);
+ // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+ assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return scaleMaxInnerProductScore(adjustedDistance);
}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
index fb8ffe3..ca815aa 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
@@ -36,6 +36,9 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.util.SameThreadExecutorService;
@@ -77,6 +80,41 @@
};
}
+ public void testQuantizationScoringEdgeCase() throws Exception {
+ float[][] vectors = new float[][] {{0.6f, 0.8f}, {0.8f, 0.6f}, {-0.6f, -0.8f}};
+ try (Directory dir = newDirectory();
+ IndexWriter w =
+ new IndexWriter(
+ dir,
+ newIndexWriterConfig()
+ .setCodec(
+ new Lucene99Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene99HnswScalarQuantizedVectorsFormat(
+ 16, 100, 1, (byte) 7, false, 0.9f, null);
+ }
+ }))) {
+ for (float[] vector : vectors) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.DOT_PRODUCT));
+ w.addDocument(doc);
+ w.commit();
+ }
+ w.forceMerge(1);
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE);
+ r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null);
+ TopDocs topDocs = topKnnCollector.topDocs();
+ assertEquals(3, topDocs.totalHits.value);
+ for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+ assertTrue(scoreDoc.score >= 0f);
+ }
+ }
+ }
+ }
+
public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java
index f926d0d..cbfecde 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java
@@ -17,7 +17,12 @@
package org.apache.lucene.codecs.lucene99;
+import static org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues.compressBytes;
+
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
@@ -32,9 +37,14 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
+import org.apache.lucene.util.quantization.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
@@ -54,6 +64,95 @@
};
}
+ public void testNonZeroScores() throws IOException {
+ for (int bits : new int[] {4, 7}) {
+ for (boolean compress : new boolean[] {true, false}) {
+ vectorNonZeroScoringTest(bits, compress);
+ }
+ }
+ }
+
+ private void vectorNonZeroScoringTest(int bits, boolean compress) throws IOException {
+ try (Directory dir = newDirectory()) {
+ // keep vecs `0` so dot product is `0`
+ byte[] vec1 = new byte[32];
+ byte[] vec2 = new byte[32];
+ if (compress && bits == 4) {
+ byte[] vec1Compressed = new byte[16];
+ byte[] vec2Compressed = new byte[16];
+ compressBytes(vec1, vec1Compressed);
+ compressBytes(vec2, vec2Compressed);
+ vec1 = vec1Compressed;
+ vec2 = vec2Compressed;
+ }
+ String fileName = getTestName() + "-32";
+ try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
+ // large negative offset to override any query score correction and
+ // ensure negative values that need to be snapped to `0`
+ var negativeOffset = floatToByteArray(-50f);
+ byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
+ out.writeBytes(bytes, 0, bytes.length);
+ }
+ ScalarQuantizer scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) bits);
+ try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
+ Lucene99ScalarQuantizedVectorScorer scorer =
+ new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
+ RandomAccessQuantizedByteVectorValues values =
+ new RandomAccessQuantizedByteVectorValues() {
+ @Override
+ public int dimension() {
+ return 32;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return compress && bits == 4 ? 16 : 32;
+ }
+
+ @Override
+ public int size() {
+ return 2;
+ }
+
+ @Override
+ public byte[] vectorValue(int ord) {
+ return new byte[32];
+ }
+
+ @Override
+ public float getScoreCorrectionConstant(int ord) {
+ return -50;
+ }
+
+ @Override
+ public RandomAccessQuantizedByteVectorValues copy() throws IOException {
+ return this;
+ }
+
+ @Override
+ public IndexInput getSlice() {
+ return in;
+ }
+
+ @Override
+ public ScalarQuantizer getScalarQuantizer() {
+ return scalarQuantizer;
+ }
+ };
+ float[] queryVector = new float[32];
+ for (int i = 0; i < 32; i++) {
+ queryVector[i] = i * 0.1f;
+ }
+ for (VectorSimilarityFunction function : VectorSimilarityFunction.values()) {
+ RandomVectorScorer randomScorer =
+ scorer.getRandomVectorScorer(function, values, queryVector);
+ assertTrue(randomScorer.score(0) >= 0f);
+ assertTrue(randomScorer.score(1) >= 0f);
+ }
+ }
+ }
+ }
+
public void testScoringCompressedInt4() throws Exception {
vectorScoringTest(4, true);
}
@@ -152,4 +251,17 @@
writer.forceMerge(1);
}
}
+
+ private static byte[] floatToByteArray(float value) {
+ return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array();
+ }
+
+ private static byte[] concat(byte[]... arrays) throws IOException {
+ try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
+ for (var ba : arrays) {
+ baos.write(ba);
+ }
+ return baos.toByteArray();
+ }
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java
index 4680460..809cb8d 100644
--- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java
+++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java
@@ -30,6 +30,26 @@
public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
+ public void testNonZeroScores() {
+ byte[][] quantized = new byte[2][32];
+ for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+ float multiplier = random().nextFloat();
+ if (random().nextBoolean()) {
+ multiplier = -multiplier;
+ }
+ for (byte bits : new byte[] {4, 7}) {
+ ScalarQuantizedVectorSimilarity quantizedSimilarity =
+ ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
+ similarityFunction, multiplier, bits);
+ float negativeOffsetA = -(random().nextFloat() * (random().nextInt(10) + 1));
+ float negativeOffsetB = -(random().nextFloat() * (random().nextInt(10) + 1));
+ float score =
+ quantizedSimilarity.score(quantized[0], negativeOffsetA, quantized[1], negativeOffsetB);
+ assertTrue(score >= 0);
+ }
+ }
+ }
+
public void testToEuclidean() throws IOException {
int dims = 128;
int numVecs = 100;