LUCENE-9905: PerFieldVectorFormat (#114)
* LUCENE-9905: PerFieldVectorFormat
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorFormat.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorFormat.java
index 1952683..304f660 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorFormat.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorFormat.java
@@ -33,6 +33,10 @@
*/
public final class SimpleTextVectorFormat extends VectorFormat {
+ public SimpleTextVectorFormat() {
+ super("SimpleTextVectorFormat");
+ }
+
@Override
public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new SimpleTextVectorWriter(state);
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/VectorFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/VectorFormat.java
index 44ae27c..ebb9976 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/VectorFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/VectorFormat.java
@@ -23,15 +23,51 @@
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
+import org.apache.lucene.util.NamedSPILoader;
/**
* Encodes/decodes per-document vector and any associated indexing structures required to support
* nearest-neighbor search
*/
-public abstract class VectorFormat {
+public abstract class VectorFormat implements NamedSPILoader.NamedSPI {
+
+ /**
+ * This static holder class prevents classloading deadlock by delaying init of doc values formats
+ * until needed.
+ */
+ private static final class Holder {
+ private static final NamedSPILoader<VectorFormat> LOADER =
+ new NamedSPILoader<>(VectorFormat.class);
+
+ private Holder() {}
+
+ static NamedSPILoader<VectorFormat> getLoader() {
+ if (LOADER == null) {
+ throw new IllegalStateException(
+ "You tried to lookup a VectorFormat name before all formats could be initialized. "
+ + "This likely happens if you call VectorFormat#forName from a VectorFormat's ctor.");
+ }
+ return LOADER;
+ }
+ }
+
+ private final String name;
/** Sole constructor */
- protected VectorFormat() {}
+ protected VectorFormat(String name) {
+ NamedSPILoader.checkServiceName(name);
+ this.name = name;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ /** looks up a format by name */
+ public static VectorFormat forName(String name) {
+ return Holder.getLoader().lookup(name);
+ }
/** Returns a {@link VectorWriter} to write the vectors to the index. */
public abstract VectorWriter fieldsWriter(SegmentWriteState state) throws IOException;
@@ -44,7 +80,7 @@
* support vectors.
*/
public static final VectorFormat EMPTY =
- new VectorFormat() {
+ new VectorFormat("EMPTY") {
@Override
public VectorWriter fieldsWriter(SegmentWriteState state) {
throw new UnsupportedOperationException(
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90Codec.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90Codec.java
index eb7818e..a0fbd56 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90Codec.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90Codec.java
@@ -32,6 +32,7 @@
import org.apache.lucene.codecs.VectorFormat;
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
+import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
/**
* Implements the Lucene 9.0 index format
@@ -84,7 +85,13 @@
}
};
- private final VectorFormat vectorFormat = new Lucene90HnswVectorFormat();
+ private final VectorFormat vectorFormat =
+ new PerFieldVectorFormat() {
+ @Override
+ public VectorFormat getVectorFormatForField(String field) {
+ return new Lucene90HnswVectorFormat();
+ }
+ };
private final StoredFieldsFormat storedFieldsFormat;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorFormat.java
index 86e6341..33bca4c 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorFormat.java
@@ -77,7 +77,9 @@
static final int VERSION_CURRENT = VERSION_START;
/** Sole constructor */
- public Lucene90HnswVectorFormat() {}
+ public Lucene90HnswVectorFormat() {
+ super("Lucene90HnswVectorFormat");
+ }
@Override
public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
index aaff806..fa7867f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
@@ -28,7 +28,6 @@
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
-import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
@@ -57,7 +56,7 @@
* @lucene.experimental
*/
public abstract class PerFieldDocValuesFormat extends DocValuesFormat {
- /** Name of this {@link PostingsFormat}. */
+ /** Name of this {@link DocValuesFormat}. */
public static final String PER_FIELD_NAME = "PerFieldDV40";
/** {@link FieldInfo} attribute name used to store the format name for each field. */
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java
new file mode 100644
index 0000000..e834722
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldVectorFormat.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.codecs.perfield;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.ServiceLoader;
+import java.util.TreeMap;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.VectorReader;
+import org.apache.lucene.codecs.VectorWriter;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.IOUtils;
+
+/**
+ * Enables per field numeric vector support.
+ *
+ * <p>Note, when extending this class, the name ({@link #getName}) is written into the index. In
+ * order for the field to be read, the name must resolve to your implementation via {@link
+ * #forName(String)}. This method uses Java's {@link ServiceLoader Service Provider Interface} to
+ * resolve format names.
+ *
+ * <p>Files written by each numeric vectors format have an additional suffix containing the format
+ * name. For example, in a per-field configuration instead of <code>_1.dat</code> filenames would
+ * look like <code>_1_Lucene40_0.dat</code>.
+ *
+ * @see ServiceLoader
+ * @lucene.experimental
+ */
+public abstract class PerFieldVectorFormat extends VectorFormat {
+ /** Name of this {@link VectorFormat}. */
+ public static final String PER_FIELD_NAME = "PerFieldVectors90";
+
+ /** {@link FieldInfo} attribute name used to store the format name for each field. */
+ public static final String PER_FIELD_FORMAT_KEY =
+ PerFieldVectorFormat.class.getSimpleName() + ".format";
+
+ /** {@link FieldInfo} attribute name used to store the segment suffix name for each field. */
+ public static final String PER_FIELD_SUFFIX_KEY =
+ PerFieldVectorFormat.class.getSimpleName() + ".suffix";
+
+ /** Sole constructor. */
+ protected PerFieldVectorFormat() {
+ super(PER_FIELD_NAME);
+ }
+
+ @Override
+ public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new FieldsWriter(state);
+ }
+
+ @Override
+ public VectorReader fieldsReader(SegmentReadState state) throws IOException {
+ return new FieldsReader(state);
+ }
+
+ /**
+ * Returns the numeric vector format that should be used for writing new segments of <code>field
+ * </code>.
+ *
+ * <p>The field to format mapping is written to the index, so this method is only invoked when
+ * writing, not when reading.
+ */
+ public abstract VectorFormat getVectorFormatForField(String field);
+
+ private class FieldsWriter extends VectorWriter {
+ private final Map<VectorFormat, WriterAndSuffix> formats;
+ private final Map<String, Integer> suffixes = new HashMap<>();
+ private final SegmentWriteState segmentWriteState;
+
+ FieldsWriter(SegmentWriteState segmentWriteState) {
+ this.segmentWriteState = segmentWriteState;
+ formats = new HashMap<>();
+ }
+
+ @Override
+ public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
+ getInstance(fieldInfo).writeField(fieldInfo, values);
+ }
+
+ @Override
+ public void finish() throws IOException {
+ for (WriterAndSuffix was : formats.values()) {
+ was.writer.finish();
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(formats.values());
+ }
+
+ private VectorWriter getInstance(FieldInfo field) throws IOException {
+ VectorFormat format = null;
+ String fieldFormatName = field.getAttribute(PER_FIELD_FORMAT_KEY);
+ if (fieldFormatName != null) {
+ format = VectorFormat.forName(fieldFormatName);
+ }
+ if (format == null) {
+ format = getVectorFormatForField(field.name);
+ }
+ if (format == null) {
+ throw new IllegalStateException(
+ "invalid null VectorFormat for field=\"" + field.name + "\"");
+ }
+ final String formatName = format.getName();
+
+ field.putAttribute(PER_FIELD_FORMAT_KEY, formatName);
+ Integer suffix = null;
+
+ WriterAndSuffix writerAndSuffix = formats.get(format);
+ if (writerAndSuffix == null) {
+ // First time we are seeing this format; create a new instance
+
+ String suffixAtt = field.getAttribute(PER_FIELD_SUFFIX_KEY);
+ if (suffixAtt != null) {
+ suffix = Integer.valueOf(suffixAtt);
+ }
+
+ if (suffix == null) {
+ // bump the suffix
+ suffix = suffixes.get(formatName);
+ if (suffix == null) {
+ suffix = 0;
+ } else {
+ suffix = suffix + 1;
+ }
+ }
+ suffixes.put(formatName, suffix);
+
+ String segmentSuffix =
+ getFullSegmentSuffix(
+ segmentWriteState.segmentSuffix, getSuffix(formatName, Integer.toString(suffix)));
+ writerAndSuffix =
+ new WriterAndSuffix(
+ format.fieldsWriter(new SegmentWriteState(segmentWriteState, segmentSuffix)),
+ suffix);
+ formats.put(format, writerAndSuffix);
+ } else {
+ // we've already seen this format, so just grab its suffix
+ assert suffixes.containsKey(formatName);
+ suffix = writerAndSuffix.suffix;
+ }
+
+ field.putAttribute(PER_FIELD_SUFFIX_KEY, Integer.toString(suffix));
+ return writerAndSuffix.writer;
+ }
+ }
+
+ /** VectorReader that can wrap multiple delegate readers, selected by field. */
+ public static class FieldsReader extends VectorReader {
+
+ private final Map<String, VectorReader> fields = new TreeMap<>();
+
+ /**
+ * Create a FieldsReader over a segment, opening VectorReaders for each VectorFormat specified
+ * by the indexed numeric vector fields.
+ *
+ * @param readState defines the fields
+ * @throws IOException if one of the delegate readers throws
+ */
+ public FieldsReader(final SegmentReadState readState) throws IOException {
+
+ // Init each unique format:
+ boolean success = false;
+ Map<String, VectorReader> formats = new HashMap<>();
+ try {
+ // Read field name -> format name
+ for (FieldInfo fi : readState.fieldInfos) {
+ if (fi.hasVectorValues()) {
+ final String fieldName = fi.name;
+ final String formatName = fi.getAttribute(PER_FIELD_FORMAT_KEY);
+ if (formatName != null) {
+ // null formatName means the field is in fieldInfos, but has no vectors!
+ final String suffix = fi.getAttribute(PER_FIELD_SUFFIX_KEY);
+ if (suffix == null) {
+ throw new IllegalStateException(
+ "missing attribute: " + PER_FIELD_SUFFIX_KEY + " for field: " + fieldName);
+ }
+ VectorFormat format = VectorFormat.forName(formatName);
+ String segmentSuffix =
+ getFullSegmentSuffix(readState.segmentSuffix, getSuffix(formatName, suffix));
+ if (!formats.containsKey(segmentSuffix)) {
+ formats.put(
+ segmentSuffix,
+ format.fieldsReader(new SegmentReadState(readState, segmentSuffix)));
+ }
+ fields.put(fieldName, formats.get(segmentSuffix));
+ }
+ }
+ }
+ success = true;
+ } finally {
+ if (!success) {
+ IOUtils.closeWhileHandlingException(formats.values());
+ }
+ }
+ }
+
+ /**
+ * Return the underlying VectorReader for the given field
+ *
+ * @param field the name of a numeric vector field
+ */
+ public VectorReader getFieldReader(String field) {
+ return fields.get(field);
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ for (VectorReader reader : fields.values()) {
+ reader.checkIntegrity();
+ }
+ }
+
+ @Override
+ public VectorValues getVectorValues(String field) throws IOException {
+ VectorReader vectorReader = fields.get(field);
+ if (vectorReader == null) {
+ return null;
+ } else {
+ return vectorReader.getVectorValues(field);
+ }
+ }
+
+ @Override
+ public TopDocs search(String field, float[] target, int k, int fanout) throws IOException {
+ VectorReader vectorReader = fields.get(field);
+ if (vectorReader == null) {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ } else {
+ return vectorReader.search(field, target, k, fanout);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(fields.values());
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long total = 0;
+ for (VectorReader reader : fields.values()) {
+ total += reader.ramBytesUsed();
+ }
+ return total;
+ }
+ }
+
+ static String getSuffix(String formatName, String suffix) {
+ return formatName + "_" + suffix;
+ }
+
+ static String getFullSegmentSuffix(String outerSegmentSuffix, String segmentSuffix) {
+ if (outerSegmentSuffix.length() == 0) {
+ return segmentSuffix;
+ } else {
+ return outerSegmentSuffix + "_" + segmentSuffix;
+ }
+ }
+
+ private static class WriterAndSuffix implements Closeable {
+ final VectorWriter writer;
+ final int suffix;
+
+ WriterAndSuffix(VectorWriter writer, int suffix) {
+ this.writer = writer;
+ this.suffix = suffix;
+ }
+
+ @Override
+ public void close() throws IOException {
+ writer.close();
+ }
+ }
+}
diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat
new file mode 100644
index 0000000..0242c3a
--- /dev/null
+++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat
\ No newline at end of file
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java
new file mode 100644
index 0000000..0ae1e5b
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldVectorFormat.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.perfield;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Random;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.asserting.AssertingCodec;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.VectorField;
+import org.apache.lucene.index.BaseVectorFormatTestCase;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.RandomCodec;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.TestUtil;
+
+/** Basic tests of PerFieldDocValuesFormat */
+public class TestPerFieldVectorFormat extends BaseVectorFormatTestCase {
+ private Codec codec;
+
+ @Override
+ public void setUp() throws Exception {
+ codec = new RandomCodec(new Random(random().nextLong()), Collections.emptySet());
+ super.setUp();
+ }
+
+ @Override
+ protected Codec getCodec() {
+ return codec;
+ }
+
+ // just a simple trivial test
+ public void testTwoFieldsTwoFormats() throws IOException {
+ Analyzer analyzer = new MockAnalyzer(random());
+
+ try (Directory directory = newDirectory()) {
+ // we don't use RandomIndexWriter because it might add more values than we expect !!!!1
+ IndexWriterConfig iwc = newIndexWriterConfig(analyzer);
+ final VectorFormat fast = TestUtil.getDefaultVectorFormat();
+ final VectorFormat slow = VectorFormat.forName("Asserting");
+ iwc.setCodec(
+ new AssertingCodec() {
+ @Override
+ public VectorFormat getVectorFormatForField(String field) {
+ if ("v1".equals(field)) {
+ return fast;
+ } else {
+ return slow;
+ }
+ }
+ });
+ try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
+ Document doc = new Document();
+ doc.add(newTextField("id", "1", Field.Store.YES));
+ doc.add(new VectorField("v1", new float[] {1, 2, 3}));
+ iwriter.addDocument(doc);
+ doc = new Document();
+ doc.add(newTextField("id", "2", Field.Store.YES));
+ doc.add(new VectorField("v2", new float[] {4, 5, 6}));
+ iwriter.addDocument(doc);
+ }
+
+ // Now search the index:
+ try (IndexReader ireader = DirectoryReader.open(directory)) {
+ TopDocs hits1 =
+ ireader
+ .leaves()
+ .get(0)
+ .reader()
+ .searchNearestVectors("v1", new float[] {1, 2, 3}, 10, 1);
+ assertEquals(1, hits1.scoreDocs.length);
+ TopDocs hits2 =
+ ireader
+ .leaves()
+ .get(0)
+ .reader()
+ .searchNearestVectors("v2", new float[] {1, 2, 3}, 10, 1);
+ assertEquals(1, hits2.scoreDocs.length);
+ }
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index 679bd46..6a5b272 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -28,6 +28,7 @@
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
+import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
@@ -171,9 +172,11 @@
iw.forceMerge(1);
}
try (IndexReader reader = DirectoryReader.open(dir)) {
+ PerFieldVectorFormat.FieldsReader perFieldReader =
+ (PerFieldVectorFormat.FieldsReader)
+ ((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
Lucene90HnswVectorReader vectorReader =
- ((Lucene90HnswVectorReader)
- ((CodecReader) getOnlyLeafReader(reader)).getVectorReader());
+ (Lucene90HnswVectorReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
}
}
@@ -310,11 +313,13 @@
for (LeafReaderContext ctx : dr.leaves()) {
LeafReader reader = ctx.reader();
VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD);
- Lucene90HnswVectorReader vectorReader =
- ((Lucene90HnswVectorReader) ((CodecReader) reader).getVectorReader());
- if (vectorReader == null) {
+ PerFieldVectorFormat.FieldsReader perFieldReader =
+ (PerFieldVectorFormat.FieldsReader) ((CodecReader) reader).getVectorReader();
+ if (perFieldReader == null) {
continue;
}
+ Lucene90HnswVectorReader vectorReader =
+ (Lucene90HnswVectorReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD);
assertEquals((vectorValues == null), (graphValues == null));
if (vectorValues == null) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
index edd13f1..da7b60e 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
@@ -26,6 +26,7 @@
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
+import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
@@ -90,7 +91,10 @@
assertEquals(indexedDoc, ctx.reader().numDocs());
assertVectorsEqual(v3, values);
KnnGraphValues graphValues =
- ((Lucene90HnswVectorReader) ((CodecReader) ctx.reader()).getVectorReader())
+ ((Lucene90HnswVectorReader)
+ ((PerFieldVectorFormat.FieldsReader)
+ ((CodecReader) ctx.reader()).getVectorReader())
+ .getFieldReader("field"))
.getGraphValues("field");
assertGraphEqual(hnsw, graphValues, nVec);
}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingCodec.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingCodec.java
index 91260d0..5d1dcf3 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingCodec.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingCodec.java
@@ -24,6 +24,7 @@
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
+import org.apache.lucene.codecs.VectorFormat;
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
import org.apache.lucene.util.TestUtil;
@@ -67,6 +68,7 @@
private final PostingsFormat defaultFormat = new AssertingPostingsFormat();
private final DocValuesFormat defaultDVFormat = new AssertingDocValuesFormat();
private final PointsFormat pointsFormat = new AssertingPointsFormat();
+ private final VectorFormat defaultVectorFormat = new AssertingVectorFormat();
public AssertingCodec() {
super("Asserting", TestUtil.getDefaultCodec());
@@ -108,6 +110,11 @@
}
@Override
+ public VectorFormat vectorFormat() {
+ return defaultVectorFormat;
+ }
+
+ @Override
public String toString() {
return "Asserting(" + delegate + ")";
}
@@ -130,4 +137,8 @@
public DocValuesFormat getDocValuesFormatForField(String field) {
return defaultDVFormat;
}
+
+ public VectorFormat getVectorFormatForField(String field) {
+ return defaultVectorFormat;
+ }
}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java
new file mode 100644
index 0000000..bae5529
--- /dev/null
+++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.codecs.asserting;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.VectorReader;
+import org.apache.lucene.codecs.VectorWriter;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.TestUtil;
+
+/** Wraps the default VectorFormat and provides additional assertions. */
+public class AssertingVectorFormat extends VectorFormat {
+
+ private final VectorFormat delegate = TestUtil.getDefaultVectorFormat();
+
+ public AssertingVectorFormat() {
+ super("Asserting");
+ }
+
+ @Override
+ public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new AssertingVectorWriter(delegate.fieldsWriter(state));
+ }
+
+ @Override
+ public VectorReader fieldsReader(SegmentReadState state) throws IOException {
+ return new AssertingVectorReader(delegate.fieldsReader(state));
+ }
+
+ static class AssertingVectorWriter extends VectorWriter {
+ final VectorWriter delegate;
+
+ AssertingVectorWriter(VectorWriter delegate) {
+ assert delegate != null;
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
+ assert fieldInfo != null;
+ assert values != null;
+ delegate.writeField(fieldInfo, values);
+ }
+
+ @Override
+ public void finish() throws IOException {
+ delegate.finish();
+ }
+
+ @Override
+ public void close() throws IOException {
+ delegate.close();
+ }
+ }
+
+ static class AssertingVectorReader extends VectorReader {
+ final VectorReader delegate;
+
+ AssertingVectorReader(VectorReader delegate) {
+ assert delegate != null;
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ delegate.checkIntegrity();
+ }
+
+ @Override
+ public VectorValues getVectorValues(String field) throws IOException {
+ VectorValues values = delegate.getVectorValues(field);
+ if (values != null) {
+ assert values.docID() == -1;
+ assert values.size() > 0;
+ assert values.dimension() > 0;
+ assert values.similarityFunction() != null;
+ }
+ return values;
+ }
+
+ @Override
+ public TopDocs search(String field, float[] target, int k, int fanout) throws IOException {
+ TopDocs hits = delegate.search(field, target, k, fanout);
+ assert hits != null;
+ assert hits.scoreDocs.length <= k;
+ return hits;
+ }
+
+ @Override
+ public void close() throws IOException {
+ delegate.close();
+ delegate.close();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ return delegate.ramBytesUsed();
+ }
+ }
+}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java
index 3e0b182..86b9b9e 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java
@@ -839,7 +839,6 @@
IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
- float[][] values = new float[numDoc][];
float[][] id2value = new float[numDoc][];
int[] id2ord = new int[numDoc];
for (int i = 0; i < numDoc; i++) {
@@ -851,7 +850,6 @@
} else {
value = null;
}
- values[i] = value;
id2value[id] = value;
id2ord[id] = i;
add(iw, fieldName, id, value, VectorValues.SimilarityFunction.EUCLIDEAN);
diff --git a/lucene/test-framework/src/java/org/apache/lucene/util/TestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/util/TestUtil.java
index 726ffe1..239b9ea 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/util/TestUtil.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/util/TestUtil.java
@@ -50,11 +50,13 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.PostingsFormat;
+import org.apache.lucene.codecs.VectorFormat;
import org.apache.lucene.codecs.asserting.AssertingCodec;
import org.apache.lucene.codecs.blockterms.LuceneFixedGap;
import org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
+import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
@@ -1294,6 +1296,13 @@
return true;
}
+ /**
+ * Returns the actual default vector format (e.g. LuceneMNVectorFormat for this version of Lucene.
+ */
+ public static VectorFormat getDefaultVectorFormat() {
+ return new Lucene90HnswVectorFormat();
+ }
+
public static boolean anyFilesExceptWriteLock(Directory dir) throws IOException {
String[] files = dir.listAll();
if (files.length > 1 || (files.length == 1 && !files[0].equals("write.lock"))) {
diff --git a/lucene/test-framework/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat b/lucene/test-framework/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat
new file mode 100644
index 0000000..526d78f
--- /dev/null
+++ b/lucene/test-framework/src/resources/META-INF/services/org.apache.lucene.codecs.VectorFormat
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+org.apache.lucene.codecs.asserting.AssertingVectorFormat
\ No newline at end of file