blob: ca9f5ef7c94b6e2361182c474f26577888661219 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
/**
* Buffers up pending vector value(s) per doc, then flushes them when the segment flushes. Used for
* {@code SimpleTextKnnVectorsWriter} and for vectors writers before v9.3.
*
* @lucene.experimental
*/
public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
// One buffering writer per field registered through addField, in registration order. flush()
// iterates this list and hands each field's buffered vectors to writeField.
private final List<FieldWriter> fields = new ArrayList<>();
/** Sole constructor; performs no work beyond creating the empty per-field buffer list. */
protected BufferingKnnVectorsWriter() {}
/**
 * Creates an in-memory buffering writer for {@code fieldInfo}, remembers it so that {@link
 * #flush} can later write its contents, and returns it to the caller.
 *
 * @param fieldInfo the field whose vectors will be buffered
 * @return the writer that accumulates this field's vectors
 */
@Override
public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
  final FieldWriter bufferedField = new FieldWriter(fieldInfo);
  fields.add(bufferedField);
  return bufferedField;
}
/**
 * Writes every buffered field by passing it to {@link #writeField} wrapped in a minimal
 * reader.
 *
 * <p>The reader handed to {@code writeField} only serves {@code getVectorValues} (the field name
 * argument is ignored) and {@code ramBytesUsed}; {@code close}, {@code checkIntegrity} and both
 * search methods throw {@link UnsupportedOperationException}.
 *
 * @param maxDoc forwarded unchanged to {@code writeField}
 * @param sortMap when non-null, the buffered values are wrapped in a {@code SortingVectorValues}
 *     built from this map; when null they are exposed as buffered
 */
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
  for (FieldWriter buffered : fields) {
    KnnVectorsReader bufferedReader =
        new KnnVectorsReader() {
          @Override
          public VectorValues getVectorValues(String field) throws IOException {
            // A fresh view is built on every call, each with its own iterator state.
            VectorValues unsorted =
                new BufferedVectorValues(
                    buffered.docsWithField,
                    buffered.vectors,
                    buffered.fieldInfo.getVectorDimension());
            if (sortMap == null) {
              return unsorted;
            }
            return new VectorValues.SortingVectorValues(unsorted, sortMap);
          }

          @Override
          public long ramBytesUsed() {
            return 0;
          }

          @Override
          public void checkIntegrity() {
            throw new UnsupportedOperationException();
          }

          @Override
          public void close() {
            throw new UnsupportedOperationException();
          }

          @Override
          public TopDocs search(
              String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
            throw new UnsupportedOperationException();
          }

          @Override
          public TopDocs searchExhaustively(
              String field, float[] target, int k, DocIdSetIterator acceptDocs) {
            throw new UnsupportedOperationException();
          }
        };
    writeField(buffered.fieldInfo, bufferedReader, maxDoc);
  }
}
/**
 * Returns the sum of the RAM estimates reported by every per-field buffering writer.
 *
 * @return total estimated bytes held by all buffered fields; 0 when no field was added
 */
@Override
public long ramBytesUsed() {
  return fields.stream().mapToLong(FieldWriter::ramBytesUsed).sum();
}
/**
 * Merges one field by exposing the merged vector values through a minimal reader and passing
 * that reader to {@link #writeField}.
 *
 * <p>The reader's {@code getVectorValues} ignores its field name argument and always returns
 * {@code MergedVectorValues.mergeVectorValues(fieldInfo, mergeState)}. {@code close} and {@code
 * checkIntegrity} are no-ops, {@code ramBytesUsed} reports 0, and both search methods throw
 * {@link UnsupportedOperationException}.
 *
 * @param fieldInfo the field being merged
 * @param mergeState merge state supplying the source values and the merged segment's maxDoc
 */
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
  KnnVectorsReader mergingReader =
      new KnnVectorsReader() {
        @Override
        public VectorValues getVectorValues(String field) throws IOException {
          return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
        }

        @Override
        public void checkIntegrity() {}

        @Override
        public void close() {}

        @Override
        public long ramBytesUsed() {
          return 0;
        }

        @Override
        public TopDocs searchExhaustively(
            String field, float[] target, int k, DocIdSetIterator acceptDocs) {
          throw new UnsupportedOperationException();
        }

        @Override
        public TopDocs search(
            String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
          throw new UnsupportedOperationException();
        }
      };
  final int mergedMaxDoc = mergeState.segmentInfo.maxDoc();
  writeField(fieldInfo, mergingReader, mergedMaxDoc);
}
/**
 * Writes all vectors of the provided field.
 *
 * <p>This class calls it from two places: {@link #flush}, once per buffered field, with a reader
 * over the vectors buffered in memory (wrapped in sorted order when a sort map was given), and
 * {@link #mergeOneField}, with a reader over the merged vector values. In both cases the
 * reader's {@code getVectorValues} ignores its field name argument, and its {@code search} and
 * {@code searchExhaustively} methods throw {@link UnsupportedOperationException}.
 *
 * @param fieldInfo the field whose vectors are to be written
 * @param knnVectorsReader source of the vectors, obtained through {@code getVectorValues}
 * @param maxDoc the {@code maxDoc} given to {@code flush}, or the merged segment's {@code
 *     maxDoc} when called from {@code mergeOneField}
 * @throws IOException if an I/O error occurs while writing
 */
protected abstract void writeField(
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
/**
 * Buffers the vectors of a single field in memory, one {@code float[]} per document, together
 * with the set of documents that have a value.
 */
private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
  private final FieldInfo fieldInfo;
  // Cached fieldInfo.getVectorDimension(); every buffered vector has exactly this length.
  private final int dim;
  private final DocsWithFieldSet docsWithField;
  // vectors.get(i) belongs to the i-th document added to docsWithField.
  private final List<float[]> vectors;
  private int lastDocID = -1;

  /**
   * Creates an empty buffer for the given field.
   *
   * @param fieldInfo the field whose vectors are buffered; supplies the dimension and encoding
   */
  public FieldWriter(FieldInfo fieldInfo) {
    this.fieldInfo = fieldInfo;
    this.dim = fieldInfo.getVectorDimension();
    this.docsWithField = new DocsWithFieldSet();
    this.vectors = new ArrayList<>();
  }

  /**
   * Buffers the vector of one document.
   *
   * @param docID the document; must be greater than the previously added one (asserted)
   * @param value a {@code float[]} when the field's encoding is FLOAT32, a {@link BytesRef} when
   *     it is BYTE
   * @throws IllegalArgumentException if {@code docID} equals the previously added document
   */
  @Override
  public void addValue(int docID, Object value) {
    if (docID == lastDocID) {
      throw new IllegalArgumentException(
          "VectorValuesField \""
              + fieldInfo.name
              + "\" appears more than once in this document (only one value is allowed per field)");
    }
    assert docID > lastDocID;
    // Build the private copy BEFORE touching any state: if the conversion or copy throws (wrong
    // type, too-short array), docsWithField and vectors stay in sync. The BYTE path already
    // produces a fresh array, so it needs no additional copy.
    float[] ownedVector =
        switch (fieldInfo.getVectorEncoding()) {
          case FLOAT32 -> copyValue((float[]) value);
          case BYTE -> bytesToFloats((BytesRef) value);
        };
    docsWithField.add(docID);
    vectors.add(ownedVector);
    lastDocID = docID;
  }

  /** Widens the first {@code dim} bytes of {@code b}, starting at its offset, to floats. */
  private float[] bytesToFloats(BytesRef b) {
    // This is used only by SimpleTextKnnVectorsWriter
    float[] floats = new float[dim];
    for (int i = 0; i < dim; i++) {
      floats[i] = b.bytes[i + b.offset];
    }
    return floats;
  }

  /** Returns a new array holding the first {@code dim} elements of {@code vectorValue}. */
  @Override
  public float[] copyValue(float[] vectorValue) {
    return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
  }

  /**
   * Estimates the memory held by this buffer: the doc set, plus one object reference and one
   * array header per vector, plus the float payload.
   *
   * @return estimated bytes, or 0 when nothing has been buffered
   */
  @Override
  public long ramBytesUsed() {
    if (vectors.isEmpty()) {
      return 0;
    }
    // Widen to long before multiplying: count * dim * Float.BYTES exceeds Integer.MAX_VALUE
    // once more than 2 GiB of vectors are buffered, and int arithmetic would silently wrap.
    final long count = vectors.size();
    final long perVectorOverhead =
        RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
    return docsWithField.ramBytesUsed()
        + count * perVectorOverhead
        + count * dim * Float.BYTES;
  }
}
/**
 * {@link VectorValues} view over vectors held in memory, usable both as a forward-only iterator
 * and for random access by ordinal.
 *
 * <p>The {@code BytesRef} returned by each {@code binaryValue} variant wraps a scratch buffer
 * owned by this instance and is overwritten by the next call to the same variant. Iteration and
 * random access use separate scratch buffers. {@code advance} is not supported.
 */
private static class BufferedVectorValues extends VectorValues
    implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
  final DocsWithFieldSet docsWithField;
  // Always the list owned by the field's writer, which stores its own copy of each added vector
  final List<float[]> vectors;
  final int dimension;
  // Scratch space for the iterator-positioned binaryValue()
  final ByteBuffer buffer;
  final BytesRef binaryValue;
  // Separate scratch space for binaryValue(int), so the two variants do not overwrite each other
  final ByteBuffer raBuffer;
  final BytesRef raBinaryValue;
  DocIdSetIterator docsWithFieldIter;
  // Index into vectors of the current document; -1 until nextDoc() returns the first document
  int ord = -1;

  /**
   * @param docsWithField documents that have a vector; iterated by {@code nextDoc}
   * @param vectors the vectors, indexed by ordinal
   * @param dimension number of floats per vector; sizes the scratch buffers
   */
  BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
    this.docsWithField = docsWithField;
    this.vectors = vectors;
    this.dimension = dimension;
    final int byteLength = dimension * Float.BYTES;
    this.buffer = ByteBuffer.allocate(byteLength).order(ByteOrder.LITTLE_ENDIAN);
    this.binaryValue = new BytesRef(this.buffer.array());
    this.raBuffer = ByteBuffer.allocate(byteLength).order(ByteOrder.LITTLE_ENDIAN);
    this.raBinaryValue = new BytesRef(this.raBuffer.array());
    this.docsWithFieldIter = docsWithField.iterator();
  }

  /** Returns a new instance over the same data with its own iterator and scratch buffers. */
  @Override
  public RandomAccessVectorValues randomAccess() {
    return new BufferedVectorValues(docsWithField, vectors, dimension);
  }

  @Override
  public int size() {
    return vectors.size();
  }

  @Override
  public int dimension() {
    return dimension;
  }

  /** Returns the buffered array of the current document itself, not a copy. */
  @Override
  public float[] vectorValue() {
    return vectors.get(ord);
  }

  /** Returns the buffered array stored at {@code targetOrd} itself, not a copy. */
  @Override
  public float[] vectorValue(int targetOrd) {
    return vectors.get(targetOrd);
  }

  /** Encodes the current document's vector as little-endian floats into the shared result. */
  @Override
  public BytesRef binaryValue() {
    // asFloatBuffer() yields a fresh view starting at position 0, so each call overwrites the
    // whole scratch buffer.
    buffer.asFloatBuffer().put(vectors.get(ord));
    return binaryValue;
  }

  /** Encodes the vector at {@code targetOrd} as little-endian floats into the shared result. */
  @Override
  public BytesRef binaryValue(int targetOrd) {
    final float[] requested = vectors.get(targetOrd);
    raBuffer.asFloatBuffer().put(requested);
    return raBinaryValue;
  }

  @Override
  public int docID() {
    return docsWithFieldIter.docID();
  }

  /** Moves to the next document; the ordinal advances only while documents remain. */
  @Override
  public int nextDoc() throws IOException {
    final int doc = docsWithFieldIter.nextDoc();
    if (doc == NO_MORE_DOCS) {
      return doc;
    }
    ord++;
    return doc;
  }

  @Override
  public int advance(int target) {
    throw new UnsupportedOperationException();
  }

  @Override
  public long cost() {
    return docsWithFieldIter.cost();
  }
}
}