| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.lucene.codecs.lucene90; |
| |
| import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; |
| |
| import java.io.IOException; |
| import java.util.Arrays; |
| import org.apache.lucene.codecs.CodecUtil; |
| import org.apache.lucene.codecs.VectorWriter; |
| import org.apache.lucene.index.FieldInfo; |
| import org.apache.lucene.index.IndexFileNames; |
| import org.apache.lucene.index.RandomAccessVectorValuesProducer; |
| import org.apache.lucene.index.SegmentWriteState; |
| import org.apache.lucene.index.VectorValues; |
| import org.apache.lucene.store.IndexOutput; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.IOUtils; |
| import org.apache.lucene.util.hnsw.HnswGraph; |
| import org.apache.lucene.util.hnsw.HnswGraphBuilder; |
| import org.apache.lucene.util.hnsw.NeighborArray; |
| |
| /** |
| * Writes vector values and knn graphs to index segments. |
| * |
| * @lucene.experimental |
| */ |
| public final class Lucene90VectorWriter extends VectorWriter { |
| |
| private final SegmentWriteState segmentWriteState; |
| private final IndexOutput meta, vectorData, vectorIndex; |
| |
| private boolean finished; |
| |
| Lucene90VectorWriter(SegmentWriteState state) throws IOException { |
| assert state.fieldInfos.hasVectorValues(); |
| segmentWriteState = state; |
| |
| String metaFileName = |
| IndexFileNames.segmentFileName( |
| state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION); |
| |
| String vectorDataFileName = |
| IndexFileNames.segmentFileName( |
| state.segmentInfo.name, |
| state.segmentSuffix, |
| Lucene90VectorFormat.VECTOR_DATA_EXTENSION); |
| |
| String indexDataFileName = |
| IndexFileNames.segmentFileName( |
| state.segmentInfo.name, |
| state.segmentSuffix, |
| Lucene90VectorFormat.VECTOR_INDEX_EXTENSION); |
| |
| boolean success = false; |
| try { |
| meta = state.directory.createOutput(metaFileName, state.context); |
| vectorData = state.directory.createOutput(vectorDataFileName, state.context); |
| vectorIndex = state.directory.createOutput(indexDataFileName, state.context); |
| |
| CodecUtil.writeIndexHeader( |
| meta, |
| Lucene90VectorFormat.META_CODEC_NAME, |
| Lucene90VectorFormat.VERSION_CURRENT, |
| state.segmentInfo.getId(), |
| state.segmentSuffix); |
| CodecUtil.writeIndexHeader( |
| vectorData, |
| Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, |
| Lucene90VectorFormat.VERSION_CURRENT, |
| state.segmentInfo.getId(), |
| state.segmentSuffix); |
| CodecUtil.writeIndexHeader( |
| vectorIndex, |
| Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, |
| Lucene90VectorFormat.VERSION_CURRENT, |
| state.segmentInfo.getId(), |
| state.segmentSuffix); |
| success = true; |
| } finally { |
| if (success == false) { |
| IOUtils.closeWhileHandlingException(this); |
| } |
| } |
| } |
| |
| @Override |
| public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException { |
| long pos = vectorData.getFilePointer(); |
| // write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when |
| // CFS is not used, eg for larger indexes |
| long padding = (4 - (pos & 0x3)) & 0x3; |
| long vectorDataOffset = pos + padding; |
| for (int i = 0; i < padding; i++) { |
| vectorData.writeByte((byte) 0); |
| } |
| // TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index |
| int[] docIds = new int[vectors.size()]; |
| int count = 0; |
| for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) { |
| // write vector |
| writeVectorValue(vectors); |
| docIds[count] = docV; |
| } |
| // count may be < vectors.size() e,g, if some documents were deleted |
| long[] offsets = new long[count]; |
| long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; |
| long vectorIndexOffset = vectorIndex.getFilePointer(); |
| if (vectors.searchStrategy().isHnsw()) { |
| if (vectors instanceof RandomAccessVectorValuesProducer) { |
| writeGraph( |
| vectorIndex, |
| (RandomAccessVectorValuesProducer) vectors, |
| vectorIndexOffset, |
| offsets, |
| count, |
| fieldInfo.getAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY), |
| fieldInfo.getAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY)); |
| } else { |
| throw new IllegalArgumentException( |
| "Indexing an HNSW graph requires a random access vector values, got " + vectors); |
| } |
| } |
| long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; |
| if (vectorDataLength > 0) { |
| writeMeta( |
| fieldInfo, |
| vectorDataOffset, |
| vectorDataLength, |
| vectorIndexOffset, |
| vectorIndexLength, |
| count, |
| docIds); |
| if (vectors.searchStrategy().isHnsw()) { |
| writeGraphOffsets(meta, offsets); |
| } |
| } |
| } |
| |
| private void writeMeta( |
| FieldInfo field, |
| long vectorDataOffset, |
| long vectorDataLength, |
| long indexDataOffset, |
| long indexDataLength, |
| int size, |
| int[] docIds) |
| throws IOException { |
| meta.writeInt(field.number); |
| meta.writeInt(field.getVectorSearchStrategy().ordinal()); |
| meta.writeVLong(vectorDataOffset); |
| meta.writeVLong(vectorDataLength); |
| meta.writeVLong(indexDataOffset); |
| meta.writeVLong(indexDataLength); |
| meta.writeInt(field.getVectorDimension()); |
| meta.writeInt(size); |
| for (int i = 0; i < size; i++) { |
| // TODO: delta-encode, or write as bitset |
| meta.writeVInt(docIds[i]); |
| } |
| } |
| |
| private void writeVectorValue(VectorValues vectors) throws IOException { |
| // write vector value |
| BytesRef binaryValue = vectors.binaryValue(); |
| assert binaryValue.length == vectors.dimension() * Float.BYTES; |
| vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); |
| } |
| |
| private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException { |
| long last = 0; |
| for (long offset : offsets) { |
| out.writeVLong(offset - last); |
| last = offset; |
| } |
| } |
| |
| private void writeGraph( |
| IndexOutput graphData, |
| RandomAccessVectorValuesProducer vectorValues, |
| long graphDataOffset, |
| long[] offsets, |
| int count, |
| String maxConnStr, |
| String beamWidthStr) |
| throws IOException { |
| int maxConn, beamWidth; |
| if (maxConnStr == null) { |
| maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN; |
| } else { |
| try { |
| maxConn = Integer.parseInt(maxConnStr); |
| } catch (NumberFormatException e) { |
| throw new NumberFormatException( |
| "Received non integer value for max-connections parameter of HnswGraphBuilder, value: " |
| + maxConnStr); |
| } |
| } |
| if (beamWidthStr == null) { |
| beamWidth = HnswGraphBuilder.DEFAULT_BEAM_WIDTH; |
| } else { |
| try { |
| beamWidth = Integer.parseInt(beamWidthStr); |
| } catch (NumberFormatException e) { |
| throw new NumberFormatException( |
| "Received non integer value for beam-width parameter of HnswGraphBuilder, value: " |
| + beamWidthStr); |
| } |
| } |
| HnswGraphBuilder hnswGraphBuilder = |
| new HnswGraphBuilder(vectorValues, maxConn, beamWidth, HnswGraphBuilder.randSeed); |
| hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); |
| HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); |
| |
| for (int ord = 0; ord < count; ord++) { |
| // write graph |
| offsets[ord] = graphData.getFilePointer() - graphDataOffset; |
| |
| NeighborArray neighbors = graph.getNeighbors(ord); |
| int size = neighbors.size(); |
| |
| // Destructively modify; it's ok we are discarding it after this |
| int[] nodes = neighbors.node(); |
| Arrays.sort(nodes, 0, size); |
| graphData.writeInt(size); |
| |
| int lastNode = -1; // to make the assertion work? |
| for (int i = 0; i < size; i++) { |
| int node = nodes[i]; |
| assert node > lastNode : "nodes out of order: " + lastNode + "," + node; |
| assert node < offsets.length : "node too large: " + node + ">=" + offsets.length; |
| graphData.writeVInt(node - lastNode); |
| lastNode = node; |
| } |
| } |
| } |
| |
| @Override |
| public void finish() throws IOException { |
| if (finished) { |
| throw new IllegalStateException("already finished"); |
| } |
| finished = true; |
| |
| if (meta != null) { |
| // write end of fields marker |
| meta.writeInt(-1); |
| CodecUtil.writeFooter(meta); |
| } |
| if (vectorData != null) { |
| CodecUtil.writeFooter(vectorData); |
| CodecUtil.writeFooter(vectorIndex); |
| } |
| } |
| |
| @Override |
| public void close() throws IOException { |
| IOUtils.close(meta, vectorData, vectorIndex); |
| } |
| } |