/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
* Writes quantized vector values and metadata to index segments.
*
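* <p>A minimal usage sketch (assuming a {@code SegmentWriteState state}, a raw
* {@code FlatVectorsWriter rawWriter}, and a {@code FlatVectorsScorer scorer} supplied by the
* enclosing format; a {@code null} confidence interval selects one automatically):
*
* <pre>{@code
* FlatVectorsWriter writer =
*     new Lucene99ScalarQuantizedVectorsWriter(state, null, (byte) 7, false, rawWriter, scorer);
* }</pre>
*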
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
private static final long SHALLOW_RAM_BYTES_USED =
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
// Used for determining when merged quantiles have shifted too far from the individual
// segment quantiles.
// When merging quantiles from various segments, we need to ensure that the new quantiles
// are not exceptionally different from any individual segment's quantiles.
// A large difference would imply that the quantization buckets shift too much
// for the floating point values, justifying recalculating the quantiles. This helps preserve
// the accuracy of the calculated quantiles, even in adversarial cases such as vector clustering.
// This number was determined via empirical testing.
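// Illustrative example: if the merged quantiles span [-1.0, 1.0], the allowed drift is
// (1.0 - (-1.0)) / 32 = 0.0625; any segment whose upper or lower quantile differs from the
// merged value by more than that triggers a full recomputation (see shouldRecomputeQuantiles).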
private static final float QUANTILE_RECOMPUTE_LIMIT = 32;
// Used for determining if a new quantization state requires a re-quantization
// for a given segment.
// This ensures that in expectation 4/5 of the vector components would be unchanged by
// requantization. Furthermore, only those values within 1/5 of the centre of a quantization
// bin will be changed. In these cases the error introduced by snapping one way or the other
// is small compared to the error introduced by quantization in the first place. Finally,
// empirical testing showed that the relative error introduced by not requantizing is small
// (compared to the quantization error) and the condition is sensitive enough to detect all
// adversarial cases, such as merging clustered data.
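// The limit is applied in shouldRequantize as
// REQUANTIZATION_LIMIT * (upperQuantile - lowerQuantile) / 128, i.e. 1/5 of the width of a
// quantization bucket.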
private static final float REQUANTIZATION_LIMIT = 0.2f;
private final SegmentWriteState segmentWriteState;
private final List<FieldWriter> fields = new ArrayList<>();
private final IndexOutput meta, quantizedVectorData;
private final Float confidenceInterval;
private final FlatVectorsWriter rawVectorDelegate;
private final byte bits;
private final boolean compress;
private final int version;
private boolean finished;
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
Float confidenceInterval,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
this(
state,
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
confidenceInterval,
(byte) 7,
false,
rawVectorDelegate,
scorer);
}
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
this(
state,
Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS,
confidenceInterval,
bits,
compress,
rawVectorDelegate,
scorer);
}
private Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
int version,
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
throws IOException {
super(scorer);
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.compress = compress;
this.version = version;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
String quantizedVectorDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
this.rawVectorDelegate = rawVectorDelegate;
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
quantizedVectorData =
state.directory.createOutput(quantizedVectorDataFileName, state.context);
CodecUtil.writeIndexHeader(
meta,
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
version,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
quantizedVectorData,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
version,
state.segmentInfo.getId(),
state.segmentSuffix);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) {
throw new IllegalArgumentException(
"bits="
+ bits
+ " is not supported for odd vector dimensions; vector dimension="
+ fieldInfo.getVectorDimension());
}
FieldWriter quantizedWriter =
new FieldWriter(
confidenceInterval,
bits,
compress,
fieldInfo,
segmentWriteState.infoStream,
indexWriter);
fields.add(quantizedWriter);
indexWriter = quantizedWriter;
}
return rawVectorDelegate.addField(fieldInfo, indexWriter);
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
// Since we know we will not be searching for additional indexing, we can just write
// the vectors directly to the new segment.
// No need to use a temporary file, as we don't have to re-open it for reading.
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
ScalarQuantizer mergedQuantizationState =
mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(quantizedVectorData, byteVectorValues, bits, compress);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
}
}
@Override
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
// Simply merge the underlying delegate, which just copies the raw vector data to a new
// segment file
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
ScalarQuantizer mergedQuantizationState =
mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
return mergeOneFieldToIndex(
segmentWriteState, fieldInfo, mergeState, mergedQuantizationState);
}
// Only merge the delegate: since the field type isn't float32, quantization isn't
// supported, so bypass it.
return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter field : fields) {
field.finish();
if (sortMap == null) {
writeField(field, maxDoc);
} else {
writeSortingField(field, maxDoc, sortMap);
}
}
}
@Override
public void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
rawVectorDelegate.finish();
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (quantizedVectorData != null) {
CodecUtil.writeFooter(quantizedVectorData);
}
}
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
for (FieldWriter field : fields) {
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeQuantizedVectors(fieldData);
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
fieldData.docsWithField);
}
private void writeMeta(
FieldInfo field,
int maxDoc,
long vectorDataOffset,
long vectorDataLength,
Float confidenceInterval,
byte bits,
boolean compress,
Float lowerQuantile,
Float upperQuantile,
DocsWithFieldSet docsWithField)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVInt(field.getVectorDimension());
int count = docsWithField.cardinality();
meta.writeInt(count);
if (count > 0) {
assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
if (version >= Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) {
meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval));
meta.writeByte(bits);
meta.writeByte(compress ? (byte) 1 : (byte) 0);
} else {
meta.writeInt(
Float.floatToIntBits(
confidenceInterval == null
? calculateDefaultConfidenceInterval(field.getVectorDimension())
: confidenceInterval));
}
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
}
// write docIDs
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
}
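// Per-vector on-disk layout written below: the quantized component bytes (packed into a
// smaller array first when compression is enabled), followed by the 4-byte little-endian
// float offset correction returned by ScalarQuantizer#quantize.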
private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
? OffHeapQuantizedByteVectorValues.compressedArray(
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (float[] v : fieldData.floatVectors) {
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
} else {
quantizedVectorData.writeBytes(vector, vector.length);
}
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
}
}
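// Illustrative sorting example: if old docs {5, 9} hold vectors (buffered ords 0 and 1) and
// sortMap sends 5 -> 3 and 9 -> 0, then docIdOffsets[3] = 1 and docIdOffsets[0] = 2, giving
// ordMap = {1, 0}: new ord 0 writes old ord 1's vector.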
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeSortedQuantizedVectors(fieldData, ordMap);
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
writeMeta(
fieldData.fieldInfo,
maxDoc,
vectorDataOffset,
quantizedVectorLength,
confidenceInterval,
bits,
compress,
fieldData.minQuantile,
fieldData.maxQuantile,
newDocsWithField);
}
private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
byte[] compressedVector =
fieldData.compress
? OffHeapQuantizedByteVectorValues.compressedArray(
fieldData.fieldInfo.getVectorDimension(), bits)
: null;
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal);
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector);
quantizedVectorData.writeBytes(compressedVector, compressedVector.length);
} else {
quantizedVectorData.writeBytes(vector, vector.length);
}
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
}
}
private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
SegmentWriteState segmentWriteState,
FieldInfo fieldInfo,
MergeState mergeState,
ScalarQuantizer mergedQuantizationState)
throws IOException {
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"quantized field="
+ " confidenceInterval="
+ confidenceInterval
+ " minQuantile="
+ mergedQuantizationState.getLowerQuantile()
+ " maxQuantile="
+ mergedQuantizationState.getUpperQuantile());
}
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
IndexOutput tempQuantizedVectorData =
segmentWriteState.directory.createTempOutput(
quantizedVectorData.getName(), "temp", segmentWriteState.context);
IndexInput quantizationDataInput = null;
boolean success = false;
try {
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues, bits, compress);
CodecUtil.writeFooter(tempQuantizedVectorData);
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
segmentWriteState.directory.openInput(
tempQuantizedVectorData.getName(), segmentWriteState.context);
quantizedVectorData.copyBytes(
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
CodecUtil.retrieveChecksum(quantizationDataInput);
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
vectorDataOffset,
vectorDataLength,
confidenceInterval,
bits,
compress,
mergedQuantizationState.getLowerQuantile(),
mergedQuantizationState.getUpperQuantile(),
docsWithField);
success = true;
final IndexInput finalQuantizationDataInput = quantizationDataInput;
return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
() -> {
IOUtils.close(finalQuantizationDataInput);
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
},
docsWithField.cardinality(),
vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
mergedQuantizationState,
compress,
fieldInfo.getVectorSimilarityFunction(),
vectorsScorer,
quantizationDataInput)));
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(tempQuantizedVectorData, quantizationDataInput);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempQuantizedVectorData.getName());
}
}
}
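// mergeQuantiles computes a size-weighted mean of the per-segment quantiles. Illustrative
// example: lower quantiles -1.0 (100 vectors) and -0.5 (300 vectors) merge to
// (-1.0 * 100 + -0.5 * 300) / 400 = -0.625. Null is returned (forcing a full recalculation)
// when any segment lacks quantization state or was quantized with a different bit count.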
static ScalarQuantizer mergeQuantiles(
List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, byte bits) {
assert quantizationStates.size() == segmentSizes.size();
if (quantizationStates.isEmpty()) {
return null;
}
float lowerQuantile = 0f;
float upperQuantile = 0f;
int totalCount = 0;
for (int i = 0; i < quantizationStates.size(); i++) {
if (quantizationStates.get(i) == null) {
return null;
}
lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i);
upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i);
totalCount += segmentSizes.get(i);
if (quantizationStates.get(i).getBits() != bits) {
return null;
}
}
lowerQuantile /= totalCount;
upperQuantile /= totalCount;
return new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
}
/**
* Returns true if the quantiles of the merged state are too far from the quantiles of the
* individual states.
*
* @param mergedQuantizationState The merged quantization state
* @param quantizationStates The quantization states of the individual segments
* @return true if the quantiles should be recomputed
*/
static boolean shouldRecomputeQuantiles(
ScalarQuantizer mergedQuantizationState, List<ScalarQuantizer> quantizationStates) {
// Calculate the limit for the quantiles to be considered too far apart.
// We utilize upper & lower here to determine if the new upper and merged upper would
// drastically change the quantization buckets for floats.
// This is a fairly conservative check.
float limit =
(mergedQuantizationState.getUpperQuantile() - mergedQuantizationState.getLowerQuantile())
/ QUANTILE_RECOMPUTE_LIMIT;
for (ScalarQuantizer quantizationState : quantizationStates) {
if (Math.abs(
quantizationState.getUpperQuantile() - mergedQuantizationState.getUpperQuantile())
> limit) {
return true;
}
if (Math.abs(
quantizationState.getLowerQuantile() - mergedQuantizationState.getLowerQuantile())
> limit) {
return true;
}
}
return false;
}
private static QuantizedVectorsReader getQuantizedKnnVectorsReader(
KnnVectorsReader vectorsReader, String fieldName) {
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
vectorsReader = candidateReader.getFieldReader(fieldName);
}
if (vectorsReader instanceof QuantizedVectorsReader reader) {
return reader;
}
return null;
}
private static ScalarQuantizer getQuantizedState(
KnnVectorsReader vectorsReader, String fieldName) {
QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(vectorsReader, fieldName);
if (reader != null) {
return reader.getQuantizationState(fieldName);
}
return null;
}
/**
* Merges the quantiles of the segments and recalculates the quantiles if necessary.
*
* @param mergeState The merge state
* @param fieldInfo The field info
* @param confidenceInterval The confidence interval, or null to derive one automatically
* @param bits The number of bits used per quantized vector component
* @return The merged quantiles
* @throws IOException If there is a low-level I/O error
*/
public static ScalarQuantizer mergeAndRecalculateQuantiles(
MergeState mergeState, FieldInfo fieldInfo, Float confidenceInterval, byte bits)
throws IOException {
assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
FloatVectorValues fvv;
if (mergeState.knnVectorsReaders[i] != null
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
&& fvv.size() > 0) {
ScalarQuantizer quantizationState =
getQuantizedState(mergeState.knnVectorsReaders[i], fieldInfo.name);
// If we have quantization state, we can utilize that to make merging cheaper
quantizationStates.add(quantizationState);
segmentSizes.add(fvv.size());
}
}
ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, bits);
// Segments not providing quantization state indicate that their quantiles were never
// calculated. To be safe, we should always recalculate given a sample set over all the float
// vectors in the merged segment view.
if (mergedQuantiles == null
// For smaller `bits` values, we should always recalculate the quantiles
// TODO: this is very conservative, could we reuse information for even int4 quantization?
|| bits <= 4
|| shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
int numVectors = 0;
FloatVectorValues vectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
// iterate vectorValues and increment numVectors
for (int doc = vectorValues.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = vectorValues.nextDoc()) {
numVectors++;
}
mergedQuantiles =
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
fieldInfo.getVectorSimilarityFunction(),
numVectors,
bits)
: ScalarQuantizer.fromVectors(
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState),
confidenceInterval,
numVectors,
bits);
}
return mergedQuantiles;
}
/**
* Returns true if the quantiles of the new quantization state are too far from the quantiles of
* the existing quantization state. This would imply that floating point values would slightly
* shift quantization buckets.
*
* @param existingQuantiles The existing quantiles for a segment
* @param newQuantiles The new quantiles for a segment, could be merged, or fully re-calculated
* @return true if the floating point values should be requantized
*/
static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) {
float tol =
REQUANTIZATION_LIMIT
* (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile())
/ 128f;
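// Illustrative example: with newQuantiles spanning [-1.0, 1.0],
// tol = 0.2 * 2.0 / 128 = 0.003125; requantization is required only when either quantile
// has drifted by more than that.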
if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) {
return true;
}
return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol;
}
/**
* Writes the quantized vector values and their score correction constants to the output, and
* returns the set of documents that contain vectors.
*/
public static DocsWithFieldSet writeQuantizedVectorData(
IndexOutput output,
QuantizedByteVectorValues quantizedByteVectorValues,
byte bits,
boolean compress)
throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
final byte[] compressedVector =
compress
? OffHeapQuantizedByteVectorValues.compressedArray(
quantizedByteVectorValues.dimension(), bits)
: null;
for (int docV = quantizedByteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = quantizedByteVectorValues.nextDoc()) {
// write vector
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
assert binaryValue.length == quantizedByteVectorValues.dimension()
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
if (compressedVector != null) {
OffHeapQuantizedByteVectorValues.compressBytes(binaryValue, compressedVector);
output.writeBytes(compressedVector, compressedVector.length);
} else {
output.writeBytes(binaryValue, binaryValue.length);
}
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
docsWithField.add(docV);
}
return docsWithField;
}
@Override
public void close() throws IOException {
IOUtils.close(meta, quantizedVectorData, rawVectorDelegate);
}
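/** Buffers raw float vectors for a single field and derives its quantiles once finished. */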
static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
private final List<float[]> floatVectors;
private final FieldInfo fieldInfo;
private final Float confidenceInterval;
private final byte bits;
private final boolean compress;
private final InfoStream infoStream;
private final boolean normalize;
private float minQuantile = Float.POSITIVE_INFINITY;
private float maxQuantile = Float.NEGATIVE_INFINITY;
private boolean finished;
private final DocsWithFieldSet docsWithField;
@SuppressWarnings("unchecked")
FieldWriter(
Float confidenceInterval,
byte bits,
boolean compress,
FieldInfo fieldInfo,
InfoStream infoStream,
KnnFieldVectorsWriter<?> indexWriter) {
super((KnnFieldVectorsWriter<float[]>) indexWriter);
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.fieldInfo = fieldInfo;
this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
this.floatVectors = new ArrayList<>();
this.infoStream = infoStream;
this.docsWithField = new DocsWithFieldSet();
this.compress = compress;
}
void finish() throws IOException {
if (finished) {
return;
}
if (floatVectors.size() == 0) {
finished = true;
return;
}
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
ScalarQuantizer quantizer =
confidenceInterval == null
? ScalarQuantizer.fromVectorsAutoInterval(
floatVectorValues,
fieldInfo.getVectorSimilarityFunction(),
floatVectors.size(),
bits)
: ScalarQuantizer.fromVectors(
floatVectorValues, confidenceInterval, floatVectors.size(), bits);
minQuantile = quantizer.getLowerQuantile();
maxQuantile = quantizer.getUpperQuantile();
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"quantized field="
+ " confidenceInterval="
+ confidenceInterval
+ " bits="
+ bits
+ " minQuantile="
+ minQuantile
+ " maxQuantile="
+ maxQuantile);
}
finished = true;
}
ScalarQuantizer createQuantizer() {
assert finished;
return new ScalarQuantizer(minQuantile, maxQuantile, bits);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (floatVectors.size() == 0) return size;
return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
}
@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
docsWithField.add(docID);
floatVectors.add(vectorValue);
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, vectorValue);
}
}
@Override
public float[] copyValue(float[] vectorValue) {
throw new UnsupportedOperationException();
}
}
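/**
* Exposes a buffered list of float[] as {@link FloatVectorValues}, optionally l2-normalizing
* each vector on read.
*/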
static class FloatVectorWrapper extends FloatVectorValues {
private final List<float[]> vectorList;
private final float[] copy;
private final boolean normalize;
protected int curDoc = -1;
FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
this.vectorList = vectorList;
this.copy = new float[vectorList.get(0).length];
this.normalize = normalize;
}
@Override
public int dimension() {
return vectorList.get(0).length;
}
@Override
public int size() {
return vectorList.size();
}
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= vectorList.size()) {
throw new IOException("Current doc not set or too many iterations");
}
if (normalize) {
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
return copy;
}
return vectorList.get(curDoc);
}
@Override
public int docID() {
if (curDoc >= vectorList.size()) {
return NO_MORE_DOCS;
}
return curDoc;
}
@Override
public int nextDoc() throws IOException {
curDoc++;
return docID();
}
@Override
public int advance(int target) throws IOException {
curDoc = target;
return docID();
}
@Override
public VectorScorer scorer(float[] target) {
throw new UnsupportedOperationException();
}
}
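/** A {@link DocIDMerger.Sub} over a single segment's {@link QuantizedByteVectorValues}. */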
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
private final QuantizedByteVectorValues values;
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
}
}
/** Returns a merged view over all the segments' {@link QuantizedByteVectorValues}. */
static class MergedQuantizedVectorValues extends QuantizedByteVectorValues {
public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues(
FieldInfo fieldInfo, MergeState mergeState, ScalarQuantizer scalarQuantizer)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
if (mergeState.knnVectorsReaders[i] != null
&& mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) {
QuantizedVectorsReader reader =
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
assert scalarQuantizer != null;
final QuantizedByteVectorValueSub sub;
// Either this segment's quantization parameters are very different from the merged ones,
// or it has never been quantized.
if (reader == null
|| reader.getQuantizationState(fieldInfo.name) == null
// For smaller `bits` values, we should always recalculate the quantiles
// TODO: this is very conservative, could we reuse information for even int4
// quantization?
|| scalarQuantizer.getBits() <= 4
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
sub =
new QuantizedByteVectorValueSub(
mergeState.docMaps[i],
new QuantizedFloatVectorValues(
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name),
fieldInfo.getVectorSimilarityFunction(),
scalarQuantizer));
} else {
sub =
new QuantizedByteVectorValueSub(
mergeState.docMaps[i],
new OffsetCorrectedQuantizedByteVectorValues(
reader.getQuantizedVectorValues(fieldInfo.name),
fieldInfo.getVectorSimilarityFunction(),
scalarQuantizer,
reader.getQuantizationState(fieldInfo.name)));
}
subs.add(sub);
}
}
return new MergedQuantizedVectorValues(subs, mergeState);
}
private final List<QuantizedByteVectorValueSub> subs;
private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
private final int size;
private int docId;
private QuantizedByteVectorValueSub current;
private MergedQuantizedVectorValues(
List<QuantizedByteVectorValueSub> subs, MergeState mergeState) throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (QuantizedByteVectorValueSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public byte[] vectorValue() throws IOException {
return current.values.vectorValue();
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
}
return docId;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return subs.get(0).values.dimension();
}
@Override
public float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
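/**
* Quantizes float vectors on the fly as the iterator advances; used for segments whose
* existing quantized bytes cannot be reused.
*/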
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
private final float[] normalizedVector;
private float offsetValue = 0f;
private final VectorSimilarityFunction vectorSimilarityFunction;
public QuantizedFloatVectorValues(
FloatVectorValues values,
VectorSimilarityFunction vectorSimilarityFunction,
ScalarQuantizer quantizer) {
this.values = values;
this.quantizer = quantizer;
this.quantizedVector = new byte[values.dimension()];
this.vectorSimilarityFunction = vectorSimilarityFunction;
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
this.normalizedVector = new float[values.dimension()];
} else {
this.normalizedVector = null;
}
}
@Override
public float getScoreCorrectionConstant() {
return offsetValue;
}
@Override
public int dimension() {
return values.dimension();
}
@Override
public int size() {
return values.size();
}
@Override
public byte[] vectorValue() throws IOException {
return quantizedVector;
}
@Override
public int docID() {
return values.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = values.nextDoc();
if (doc != NO_MORE_DOCS) {
quantize();
}
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = values.advance(target);
if (doc != NO_MORE_DOCS) {
quantize();
}
return doc;
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
private void quantize() throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
offsetValue =
quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
} else {
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
}
}
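/**
* A {@link CloseableRandomVectorScorerSupplier} whose close() closes the temporary quantized
* data input and deletes the temporary file created while merging.
*/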
static final class ScalarQuantizedCloseableRandomVectorScorerSupplier
implements CloseableRandomVectorScorerSupplier {
private final RandomVectorScorerSupplier supplier;
private final Closeable onClose;
private final int numVectors;
ScalarQuantizedCloseableRandomVectorScorerSupplier(
Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) {
this.onClose = onClose;
this.supplier = supplier;
this.numVectors = numVectors;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return supplier.copy();
}
@Override
public void close() throws IOException {
onClose.close();
}
@Override
public int totalVectorCount() {
return numVectors;
}
}
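/**
* Wraps an existing segment's quantized vectors whose quantiles are close enough to reuse,
* recomputing only the score corrective offset against the new quantizer.
*/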
static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues {
private final QuantizedByteVectorValues in;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;
OffsetCorrectedQuantizedByteVectorValues(
QuantizedByteVectorValues in,
VectorSimilarityFunction vectorSimilarityFunction,
ScalarQuantizer scalarQuantizer,
ScalarQuantizer oldScalarQuantizer) {
this.in = in;
this.vectorSimilarityFunction = vectorSimilarityFunction;
this.scalarQuantizer = scalarQuantizer;
this.oldScalarQuantizer = oldScalarQuantizer;
}
@Override
public float getScoreCorrectionConstant() throws IOException {
return scalarQuantizer.recalculateCorrectiveOffset(
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
}
@Override
public int dimension() {
return in.dimension();
}
@Override
public int size() {
return in.size();
}
@Override
public byte[] vectorValue() throws IOException {
return in.vectorValue();
}
@Override
public int docID() {
return in.docID();
}
@Override
public int nextDoc() throws IOException {
return in.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return in.advance(target);
}
@Override
public VectorScorer scorer(float[] target) throws IOException {
throw new UnsupportedOperationException();
}
}
}