blob: 5a20770bb9884e6791e02a48f7c264c4aaee84ff [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;
/**
* Reads Scalar Quantized vectors from the index segments along with index data structures.
*
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader
implements QuantizedVectorsReader {
// Shallow size of one reader instance as reported by RamUsageEstimator; the starting value of
// ramBytesUsed().
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
// Per-field metadata parsed from the meta file, keyed by field name (filled by readFields).
private final Map<String, FieldEntry> fields = new HashMap<>();
// Handle on the quantized vector data file, opened in the constructor and released by close().
private final IndexInput quantizedVectorData;
// Reader for the raw vectors; several methods delegate to it and close() also closes it.
private final FlatVectorsReader rawVectorsReader;
/**
 * Opens the quantized-vector meta and data files of a segment.
 *
 * <p>The meta file is processed inside a try-with-resources block (header check, per-field
 * entries, footer check), so it is closed again before the constructor returns. The vector data
 * file is opened with {@link ReadAdvice#RANDOM} and kept in {@code quantizedVectorData}. If any
 * step fails, this reader is closed via {@link IOUtils#closeWhileHandlingException}, which in turn
 * closes the supplied {@code rawVectorsReader} as well (see {@link #close()}).
 *
 * @param state segment being opened; supplies the directory, segment name, suffix and id, the
 *     field infos and the IO context
 * @param rawVectorsReader reader for the raw vectors that this reader delegates to
 * @param scorer scorer handed to the {@link FlatVectorsReader} super constructor
 * @throws IOException if opening, reading or validating either file fails
 */
public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
throws IOException {
super(scorer);
this.rawVectorsReader = rawVectorsReader;
// Stays -1 unless the meta header check below succeeds.
int versionMeta = -1;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
boolean success = false;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
// Any failure while reading the meta file is remembered and handed to checkFooter below
// instead of being thrown directly.
Throwable priorE = null;
try {
versionMeta =
CodecUtil.checkIndexHeader(
meta,
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta, versionMeta, state.fieldInfos);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(meta, priorE);
}
// openDataInput also verifies that the data file's version equals versionMeta.
quantizedVectorData =
openDataInput(
state,
versionMeta,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
// Quantized vectors are accessed randomly from their node ID stored in the HNSW
// graph.
state.context.withReadAdvice(ReadAdvice.RANDOM));
success = true;
} finally {
if (success == false) {
// Construction failed: release whatever is held so far. close() passes both
// quantizedVectorData (possibly still unassigned) and rawVectorsReader to IOUtils.close.
IOUtils.closeWhileHandlingException(this);
}
}
}
/**
 * Reads field entries from the meta file until the {@code -1} end marker is encountered and
 * registers each validated entry in {@code fields} under its field name.
 *
 * @param meta meta file input, positioned at the first field number
 * @param versionMeta format version of the meta file, forwarded to {@code readField}
 * @param infos field infos used to resolve the stored field numbers
 * @throws CorruptIndexException if a stored field number is unknown to {@code infos}
 * @throws IOException if reading from {@code meta} fails
 */
private void readFields(ChecksumIndexInput meta, int versionMeta, FieldInfos infos)
    throws IOException {
  while (true) {
    final int number = meta.readInt();
    if (number == -1) {
      // End-of-fields marker.
      return;
    }
    final FieldInfo fieldInfo = infos.fieldInfo(number);
    if (fieldInfo == null) {
      throw new CorruptIndexException("Invalid field number: " + number, meta);
    }
    final FieldEntry entry = readField(meta, versionMeta, fieldInfo);
    validateFieldEntry(fieldInfo, entry);
    fields.put(fieldInfo.name, entry);
  }
}
/**
 * Checks that a field entry read from the meta file is consistent with the field's {@link
 * FieldInfo}.
 *
 * <p>Two things are verified: the stored dimension equals the dimension declared by {@code info},
 * and the stored vector data length equals {@code size} times the expected number of bytes per
 * quantized vector. The per-vector byte count is {@code dimension + Float.BYTES}, or {@code
 * ((dimension + 1) >> 1) + Float.BYTES} when the entry uses at most 4 bits and is compressed.
 *
 * @param info field metadata providing the expected vector dimension and the field name
 * @param fieldEntry entry parsed from the meta file
 * @throws IllegalStateException if the dimension or the vector data length does not match
 * @throws ArithmeticException if the expected total length overflows a {@code long}
 */
static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
  int dimension = info.getVectorDimension();
  if (dimension != fieldEntry.dimension) {
    throw new IllegalStateException(
        "Inconsistent vector dimension for field=\""
            + info.name
            + "\"; "
            + dimension
            + " != "
            + fieldEntry.dimension);
  }
  // Compressed entries of at most 4 bits occupy half as many bytes (rounded up) per vector.
  final boolean packed = fieldEntry.bits <= 4 && fieldEntry.compress;
  // Computed in long so that very large dimensions cannot wrap around in int arithmetic.
  final long valueBytes = packed ? (((long) dimension + 1) >> 1) : (long) dimension;
  // Every vector carries Float.BYTES extra bytes (the "calculated stored offset" of the
  // original comment).
  final long quantizedVectorBytes = valueBytes + Float.BYTES;
  long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size);
  if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
    // Report the formula that was actually applied; the uncompressed text is unchanged.
    final String perVectorFormula =
        packed ? "((dim=" + dimension + " + 1) / 2 + 4)" : "(dim=" + dimension + " + 4)";
    throw new IllegalStateException(
        "Quantized vector data length "
            + fieldEntry.vectorDataLength
            + " not matching size="
            + fieldEntry.size
            + " * "
            + perVectorFormula
            + " = "
            + numQuantizedVectorBytes);
  }
}
/**
 * Verifies the raw vectors reader first, then checksums the entire quantized vector data file.
 *
 * @throws IOException if either check fails or the data cannot be read
 */
@Override
public void checkIntegrity() throws IOException {
rawVectorsReader.checkIntegrity();
CodecUtil.checksumEntireFile(quantizedVectorData);
}
/**
 * Returns float vector values for {@code field} whose values come from the raw vectors reader
 * and whose scorer comes from the quantized vectors.
 *
 * @param field name of the vector field
 * @return the combined values, or {@code null} if the field has no entry here or is not encoded
 *     as {@link VectorEncoding#FLOAT32}
 * @throws IOException if loading either set of values fails
 */
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
  final FieldEntry entry = fields.get(field);
  final boolean isFloatField = entry != null && entry.vectorEncoding() == VectorEncoding.FLOAT32;
  if (isFloatField == false) {
    return null;
  }
  // Arguments are evaluated left to right: raw values are fetched before the quantized ones.
  return new QuantizedVectorValues(
      rawVectorsReader.getFloatVectorValues(field),
      OffHeapQuantizedByteVectorValues.load(
          entry.ordToDoc(),
          entry.dimension(),
          entry.size(),
          entry.scalarQuantizer(),
          entry.similarityFunction(),
          vectorScorer,
          entry.compress(),
          entry.vectorDataOffset(),
          entry.vectorDataLength(),
          quantizedVectorData));
}
/**
 * Delegates directly to the raw vectors reader; no quantized data is consulted for byte vectors.
 *
 * @param field name of the vector field
 * @return whatever the raw vectors reader returns for {@code field}
 * @throws IOException if the raw vectors reader fails
 */
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return rawVectorsReader.getByteVectorValues(field);
}
/**
 * Opens a segment data file, validates its header against the segment and the meta version, and
 * retrieves its checksum. The input is closed again if any of these steps fails.
 *
 * @param state segment read state supplying directory, segment name, suffix and id
 * @param versionMeta version read from the meta file; the data file must carry the same version
 * @param fileExtension extension of the data file to open
 * @param codecName codec name expected in the data file header
 * @param context IO context used to open the file
 * @return the opened and validated input; the caller is responsible for closing it
 * @throws CorruptIndexException if the data file version differs from {@code versionMeta}
 * @throws IOException if the file cannot be opened or validated
 */
private static IndexInput openDataInput(
    SegmentReadState state,
    int versionMeta,
    String fileExtension,
    String codecName,
    IOContext context)
    throws IOException {
  final String dataFileName =
      IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
  final IndexInput in = state.directory.openInput(dataFileName, context);
  try {
    final int versionVectorData =
        CodecUtil.checkIndexHeader(
            in,
            codecName,
            Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
            Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
            state.segmentInfo.getId(),
            state.segmentSuffix);
    if (versionMeta != versionVectorData) {
      throw new CorruptIndexException(
          "Format versions mismatch: meta="
              + versionMeta
              + ", "
              + codecName
              + "="
              + versionVectorData,
          in);
    }
    CodecUtil.retrieveChecksum(in);
    return in;
  } catch (Throwable failure) {
    // Do not leak the file handle when validation fails; the original failure is rethrown.
    IOUtils.closeWhileHandlingException(in);
    throw failure;
  }
}
/**
 * Creates a scorer for a float query vector against the quantized vectors of {@code field}.
 *
 * @param field name of the vector field
 * @param target query vector
 * @return {@code null} if the field has no entry here or is not {@link VectorEncoding#FLOAT32};
 *     the raw vectors reader's scorer if the entry has no scalar quantizer; otherwise a scorer
 *     over the quantized vectors
 * @throws IOException if loading the vectors or creating the scorer fails
 */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
  final FieldEntry entry = fields.get(field);
  final boolean isFloatField = entry != null && entry.vectorEncoding() == VectorEncoding.FLOAT32;
  if (isFloatField == false) {
    return null;
  }
  final ScalarQuantizer quantizer = entry.scalarQuantizer();
  if (quantizer == null) {
    // No quantization state stored for this field: score against the raw vectors instead.
    return rawVectorsReader.getRandomVectorScorer(field, target);
  }
  final OffHeapQuantizedByteVectorValues quantized =
      OffHeapQuantizedByteVectorValues.load(
          entry.ordToDoc(),
          entry.dimension(),
          entry.size(),
          quantizer,
          entry.similarityFunction(),
          vectorScorer,
          entry.compress(),
          entry.vectorDataOffset(),
          entry.vectorDataLength(),
          quantizedVectorData);
  return vectorScorer.getRandomVectorScorer(entry.similarityFunction(), quantized, target);
}
/**
 * Delegates directly to the raw vectors reader; byte targets are not scored against the
 * quantized data held by this reader.
 *
 * @param field name of the vector field
 * @param target query vector
 * @return whatever the raw vectors reader returns for {@code field} and {@code target}
 * @throws IOException if the raw vectors reader fails
 */
@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
return rawVectorsReader.getRandomVectorScorer(field, target);
}
/**
 * Closes the quantized vector data input and the raw vectors reader.
 *
 * <p>The constructor also calls this (through {@code IOUtils.closeWhileHandlingException}) when
 * construction fails, at which point {@code quantizedVectorData} may not have been assigned yet.
 *
 * @throws IOException if closing either resource fails
 */
@Override
public void close() throws IOException {
IOUtils.close(quantizedVectorData, rawVectorsReader);
}
/**
 * Estimates heap usage as the shallow size of this reader, plus the {@code fields} map sized
 * with the shallow size of a {@code FieldEntry} per value, plus whatever the raw vectors reader
 * reports.
 *
 * @return estimated number of bytes used
 */
@Override
public long ramBytesUsed() {
  final long fieldEntryShallowSize = RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
  return SHALLOW_SIZE
      + RamUsageEstimator.sizeOfMap(fields, fieldEntryShallowSize)
      + rawVectorsReader.ramBytesUsed();
}
/**
 * Reads one field entry from the meta input: vector encoding, similarity function, then the
 * remaining entry data via {@code FieldEntry.create}.
 *
 * @param input meta input positioned at the start of the field's entry
 * @param versionMeta format version of the meta file
 * @param info field metadata that the stored similarity function must agree with
 * @return the parsed entry
 * @throws IllegalStateException if the stored similarity function differs from the one in {@code
 *     info}
 * @throws IOException if reading from {@code input} fails
 */
private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info)
    throws IOException {
  final VectorEncoding storedEncoding = readVectorEncoding(input);
  final VectorSimilarityFunction storedSimilarity = readSimilarityFunction(input);
  final VectorSimilarityFunction expectedSimilarity = info.getVectorSimilarityFunction();
  if (storedSimilarity == expectedSimilarity) {
    return FieldEntry.create(input, versionMeta, storedEncoding, expectedSimilarity);
  }
  throw new IllegalStateException(
      "Inconsistent vector similarity function for field=\""
          + info.name
          + "\"; "
          + storedSimilarity
          + " != "
          + expectedSimilarity);
}
/**
 * Loads the quantized byte vectors stored for {@code fieldName}.
 *
 * @param fieldName name of the vector field
 * @return the quantized values, or {@code null} if the field has no entry here or is not encoded
 *     as {@link VectorEncoding#FLOAT32}
 * @throws IOException if loading the values fails
 */
@Override
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
  final FieldEntry entry = fields.get(fieldName);
  final boolean isFloatField = entry != null && entry.vectorEncoding() == VectorEncoding.FLOAT32;
  if (isFloatField == false) {
    return null;
  }
  return OffHeapQuantizedByteVectorValues.load(
      entry.ordToDoc(),
      entry.dimension(),
      entry.size(),
      entry.scalarQuantizer(),
      entry.similarityFunction(),
      vectorScorer,
      entry.compress(),
      entry.vectorDataOffset(),
      entry.vectorDataLength(),
      quantizedVectorData);
}
/**
 * Returns the scalar quantizer stored for {@code fieldName}.
 *
 * @param fieldName name of the vector field
 * @return the field's quantizer, which may itself be {@code null}; also {@code null} if the field
 *     has no entry here or is not encoded as {@link VectorEncoding#FLOAT32}
 */
@Override
public ScalarQuantizer getQuantizationState(String fieldName) {
  final FieldEntry entry = fields.get(fieldName);
  final boolean isFloatField = entry != null && entry.vectorEncoding() == VectorEncoding.FLOAT32;
  return isFloatField ? entry.scalarQuantizer() : null;
}
/**
 * Per-field metadata parsed from the meta file.
 *
 * @param similarityFunction similarity function of the field
 * @param vectorEncoding encoding stored for the field
 * @param dimension vector dimension stored in the meta file
 * @param vectorDataOffset offset of the field's vectors within the quantized vector data file
 * @param vectorDataLength length in bytes of the field's quantized vector data
 * @param scalarQuantizer quantizer built from the stored quantiles, or {@code null} when {@code
 *     size} is not positive
 * @param size number of vectors stored for the field
 * @param bits number of bits per quantized value; {@code 7} when none is stored
 * @param compress whether the stored compress flag byte equals {@code 1}; {@code false} when none
 *     is stored
 * @param ordToDoc ordinal-to-document mapping configuration read after the quantizer data
 */
private record FieldEntry(
    VectorSimilarityFunction similarityFunction,
    VectorEncoding vectorEncoding,
    int dimension,
    long vectorDataOffset,
    long vectorDataLength,
    ScalarQuantizer scalarQuantizer,
    int size,
    byte bits,
    boolean compress,
    OrdToDocDISIReaderConfiguration ordToDoc) {

  /**
   * Parses the remainder of a field entry from the meta input, after the encoding and similarity
   * function have already been consumed by the caller.
   *
   * @param input meta input positioned at the entry's vector data offset
   * @param versionMeta meta format version; versions before {@code VERSION_ADD_BITS} store neither
   *     bits nor a compress flag
   * @param vectorEncoding encoding already read for this field
   * @param similarityFunction similarity function of this field
   * @return the parsed entry
   * @throws CorruptIndexException if a pre-{@code VERSION_ADD_BITS} entry stores {@code -1} as
   *     its confidence interval
   * @throws IOException if reading from {@code input} fails
   */
  static FieldEntry create(
      IndexInput input,
      int versionMeta,
      VectorEncoding vectorEncoding,
      VectorSimilarityFunction similarityFunction)
      throws IOException {
    final long dataOffset = input.readVLong();
    final long dataLength = input.readVLong();
    final int dim = input.readVInt();
    final int count = input.readInt();

    final ScalarQuantizer quantizer;
    final byte quantBits;
    final boolean compressed;
    if (count <= 0) {
      // No vectors: no quantization state is stored for this field.
      quantizer = null;
      quantBits = (byte) 7;
      compressed = false;
    } else {
      final int confidenceIntervalBits = input.readInt(); // confidenceInterval, unused
      if (versionMeta < Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) {
        if (confidenceIntervalBits == -1) {
          throw new CorruptIndexException(
              "Missing confidence interval for scalar quantizer", input);
        }
        quantBits = (byte) 7;
        compressed = false;
      } else {
        quantBits = input.readByte();
        compressed = input.readByte() == 1;
      }
      // Both format versions store the two quantiles next, as raw float bits.
      final float lowerQuantile = Float.intBitsToFloat(input.readInt());
      final float upperQuantile = Float.intBitsToFloat(input.readInt());
      quantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, quantBits);
    }

    final OrdToDocDISIReaderConfiguration ordToDocConfig =
        OrdToDocDISIReaderConfiguration.fromStoredMeta(input, count);
    return new FieldEntry(
        similarityFunction,
        vectorEncoding,
        dim,
        dataOffset,
        dataLength,
        quantizer,
        count,
        quantBits,
        compressed,
        ordToDocConfig);
  }
}
/**
 * Float vector values that report dimension, size, current document and vector values from the
 * raw vectors, while creating scorers from the quantized vectors. Both underlying iterators are
 * moved together so that they stay on the same document.
 */
private static final class QuantizedVectorValues extends FloatVectorValues {
  private final FloatVectorValues raw;
  private final OffHeapQuantizedByteVectorValues quantized;

  QuantizedVectorValues(
      FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) {
    this.raw = rawVectorValues;
    this.quantized = quantizedVectorValues;
  }

  @Override
  public int dimension() {
    return raw.dimension();
  }

  @Override
  public int size() {
    return raw.size();
  }

  @Override
  public float[] vectorValue() throws IOException {
    return raw.vectorValue();
  }

  @Override
  public int docID() {
    return raw.docID();
  }

  /** Moves both iterators to their next document and returns the quantized iterator's doc. */
  @Override
  public int nextDoc() throws IOException {
    final int docFromRaw = raw.nextDoc();
    final int docFromQuantized = quantized.nextDoc();
    assert docFromRaw == docFromQuantized;
    return docFromQuantized;
  }

  /** Advances both iterators to {@code target} and returns the quantized iterator's doc. */
  @Override
  public int advance(int target) throws IOException {
    final int docFromRaw = raw.advance(target);
    final int docFromQuantized = quantized.advance(target);
    assert docFromRaw == docFromQuantized;
    return docFromQuantized;
  }

  /** Scoring is delegated to the quantized vectors. */
  @Override
  public VectorScorer scorer(float[] query) throws IOException {
    return quantized.scorer(query);
  }
}
}