blob: 6a5b2721d845ba6c3c39e7e845b7a2110f9c3385 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.VectorValues.SimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
/** Tests indexing of a knn-graph */
/**
 * Tests indexing of a knn-graph: verifies that the HNSW graph written by the vector codec is
 * consistent with the indexed vectors (1-1 node/vector mapping, connectivity, max-fanout), that
 * merging preserves graph properties, and that nearest-neighbor search over the graph returns
 * sane results.
 *
 * <p>NOTE: several tests rely on deterministic replay of the random seed via
 * {@link HnswGraphBuilder#randSeed}; the exact order of {@code random()} calls matters.
 */
public class TestKnnGraph extends LuceneTestCase {

  /** Name of the vector field used by every test document. */
  private static final String KNN_GRAPH_FIELD = "vector";

  // Static (not per-instance) because helper methods such as copyGraph/add read it; it is
  // randomized in setup() and restored to the default in cleanup().
  private static int maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;

  /** Similarity randomly chosen per test (testSearch overrides it with EUCLIDEAN). */
  private SimilarityFunction similarityFunction;

  /**
   * Randomizes the HNSW build seed, occasionally the max-connections parameter, and the
   * similarity function for this test run.
   */
  @Before
  public void setup() {
    // Make graph construction reproducible for this run while still varying across runs.
    randSeed = random().nextLong();
    if (random().nextBoolean()) {
      maxConn = random().nextInt(256) + 3;
    }
    // Index in [1, length-1]: ordinal 0 is deliberately skipped — presumably the "none"
    // similarity, which cannot be used for graph search. TODO(review): confirm ordinal 0 meaning.
    int similarity = random().nextInt(SimilarityFunction.values().length - 1) + 1;
    similarityFunction = SimilarityFunction.values()[similarity];
  }

  /** Restores the shared static maxConn so later tests see the default. */
  @After
  public void cleanup() {
    maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
  }

  /** Basic test of creating documents in a graph */
  public void testBasic() throws Exception {
    try (Directory dir = newDirectory();
        IndexWriter iw =
            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      int numDoc = atLeast(10);
      int dimension = atLeast(3);
      float[][] values = new float[numDoc][];
      for (int i = 0; i < numDoc; i++) {
        // About half the documents get a vector; the rest stay null to exercise sparse fields.
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
        }
        add(iw, i, values[i]);
      }
      assertConsistentGraph(iw, values);
    }
  }

  /** A single-document graph must be consistent both before and after commit. */
  public void testSingleDocument() throws Exception {
    try (Directory dir = newDirectory();
        IndexWriter iw =
            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      float[][] values = new float[][] {new float[] {0, 1, 2}};
      add(iw, 0, values[0]);
      assertConsistentGraph(iw, values);
      iw.commit();
      assertConsistentGraph(iw, values);
    }
  }

  /** Verify that the graph properties are preserved when merging */
  public void testMerge() throws Exception {
    try (Directory dir = newDirectory();
        IndexWriter iw =
            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
      int numDoc = atLeast(100);
      int dimension = atLeast(10);
      float[][] values = randomVectors(numDoc, dimension);
      for (int i = 0; i < numDoc; i++) {
        // NOTE(review): this re-randomizes roughly half the entries that randomVectors already
        // filled (and leaves some of its non-null entries intact) — appears redundant but is
        // harmless for what the test checks; removing it would change the random stream.
        if (random().nextBoolean()) {
          values[i] = new float[dimension];
          for (int j = 0; j < dimension; j++) {
            values[i][j] = random().nextFloat();
          }
          VectorUtil.l2normalize(values[i]);
        }
        add(iw, i, values[i]);
        // Commit at random points so the index ends up with multiple segments to merge.
        if (random().nextInt(10) == 3) {
          iw.commit();
        }
      }
      if (random().nextBoolean()) {
        iw.forceMerge(1);
      }
      assertConsistentGraph(iw, values);
    }
  }

  /**
   * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
   * and merging.
   */
  public void testMergeProducesSameGraph() throws Exception {
    long seed = random().nextLong();
    int numDoc = atLeast(100);
    int dimension = atLeast(10);
    float[][] values = randomVectors(numDoc, dimension);
    int mergePoint = random().nextInt(numDoc);
    // Same vectors, same seed: once split into two segments and merged, once as one segment.
    int[][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
    int[][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
    assertGraphEquals(singleSegmentGraph, mergedGraph);
  }

  /** Asserts two adjacency-list graphs are identical, node by node. */
  private void assertGraphEquals(int[][] expected, int[][] actual) {
    assertEquals("graph sizes differ", expected.length, actual.length);
    for (int i = 0; i < expected.length; i++) {
      assertArrayEquals("difference at ord=" + i, expected[i], actual[i]);
    }
  }

  /**
   * Indexes the given vectors (flushing a segment after {@code mergePoint}, or never when it is
   * negative), force-merges to one segment, and returns a copy of the resulting graph's
   * adjacency lists indexed by node ordinal.
   */
  private int[][] getIndexedGraph(float[][] values, int mergePoint, long seed) throws IOException {
    // Fix the builder seed so two invocations build graphs with identical random choices.
    HnswGraphBuilder.randSeed = seed;
    int[][] graph;
    try (Directory dir = newDirectory()) {
      IndexWriterConfig iwc = newIndexWriterConfig();
      iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
      iwc.setCodec(Codec.forName("Lucene90")); // don't use SimpleTextCodec
      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
        for (int i = 0; i < values.length; i++) {
          add(iw, i, values[i]);
          if (i == mergePoint) {
            // flush proactively to create a segment
            iw.flush();
          }
        }
        iw.forceMerge(1);
      }
      try (IndexReader reader = DirectoryReader.open(dir)) {
        // Unwrap the per-field reader to reach the codec-specific HNSW reader and its graph.
        PerFieldVectorFormat.FieldsReader perFieldReader =
            (PerFieldVectorFormat.FieldsReader)
                ((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
        Lucene90HnswVectorReader vectorReader =
            (Lucene90HnswVectorReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
        graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
      }
    }
    return graph;
  }

  /**
   * Returns an array of {@code numDoc} vectors, about half of them null; non-null entries are
   * random unit vectors (l2-normalized) of the given dimension.
   */
  private float[][] randomVectors(int numDoc, int dimension) {
    float[][] values = new float[numDoc][];
    for (int i = 0; i < numDoc; i++) {
      if (random().nextBoolean()) {
        values[i] = new float[dimension];
        for (int j = 0; j < dimension; j++) {
          values[i][j] = random().nextFloat();
        }
        VectorUtil.l2normalize(values[i]);
      }
    }
    return values;
  }

  /**
   * Snapshots the neighbor lists of every node into a plain int[][] so the graph can be compared
   * after its reader is closed. Assumes each node has at most {@code maxConn} neighbors (the
   * scratch buffer is sized accordingly).
   */
  int[][] copyGraph(KnnGraphValues values) throws IOException {
    int size = values.size();
    int[][] graph = new int[size][];
    int[] scratch = new int[maxConn];
    for (int node = 0; node < size; node++) {
      int n, count = 0;
      values.seek(node);
      while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
        scratch[count++] = n;
        // graph[node][i++] = n;
      }
      graph[node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
    }
    return graph;
  }

  /** Verify that searching does something reasonable */
  public void testSearch() throws Exception {
    // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
    similarityFunction = SimilarityFunction.EUCLIDEAN;
    IndexWriterConfig config = newIndexWriterConfig();
    config.setCodec(Codec.forName("Lucene90")); // test is not compatible with simpletext
    try (Directory dir = newDirectory();
        IndexWriter iw = new IndexWriter(dir, config)) {
      // Add a document for every cartesian point in an NxN square so we can
      // easily know which are the nearest neighbors to every point. Insert by iterating
      // using a prime number that is not a divisor of N*N so that we will hit each point once,
      // and chosen so that points will be inserted in a deterministic
      // but somewhat distributed pattern
      int n = 5, stepSize = 17;
      float[][] values = new float[n * n][];
      int index = 0;
      for (int i = 0; i < values.length; i++) {
        // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
        int x = index % n, y = index / n;
        values[i] = new float[] {x, y};
        index = (index + stepSize) % (n * n);
        add(iw, i, values[i]);
        if (i == 13) {
          // create 2 segments
          iw.commit();
        }
      }
      boolean forceMerge = random().nextBoolean();
      // System.out.println("");
      if (forceMerge) {
        iw.forceMerge(1);
      }
      assertConsistentGraph(iw, values);
      try (DirectoryReader dr = DirectoryReader.open(iw)) {
        // results are ordered by score (descending) and docid (ascending);
        // This is the insertion order:
        // column major, origin at upper left
        //  0 15  5 20 10
        //  3 18  8 23 13
        //  6 21 11  1 16
        //  9 24 14  4 19
        // 12  2 17  7 22
        /* For this small graph the "search" is exhaustive, so this mostly tests the APIs, the
         * orientation of the various priority queues, the scoring function, but not so much the
         * approximate KNN search algorithm
         */
        assertGraphSearch(new int[] {0, 15, 3, 18, 5}, new float[] {0f, 0.1f}, dr);
        // Tiebreaking by docid must be done after VectorValues.search.
        // assertGraphSearch(new int[]{11, 1, 8, 14, 21}, new float[]{2, 2}, dr);
        assertGraphSearch(new int[] {15, 18, 0, 3, 5}, new float[] {0.3f, 0.8f}, dr);
      }
    }
  }

  /**
   * Runs a top-5 knn search for {@code vector} and asserts the results, mapped back from docids
   * to insertion ids via the stored "id" field, match {@code expected} in order.
   */
  private void assertGraphSearch(int[] expected, float[] vector, IndexReader reader)
      throws IOException {
    TopDocs results = doKnnSearch(reader, vector, 5);
    for (ScoreDoc doc : results.scoreDocs) {
      // map docId to insertion id
      doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
    }
    assertResults(expected, results);
  }

  /**
   * Searches each leaf independently, rebases per-leaf docids to global docids, and merges the
   * per-leaf results into a single top-{@code k} list.
   */
  private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
    TopDocs[] results = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, 10);
      if (ctx.docBase > 0) {
        for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
          doc.doc += ctx.docBase;
        }
      }
    }
    return TopDocs.merge(k, results);
  }

  /** Asserts the result docids equal {@code expected}, element by element. */
  private void assertResults(int[] expected, TopDocs results) {
    assertEquals(results.toString(), expected.length, results.scoreDocs.length);
    // Iterate backwards so the first failures reported are the lowest-scoring (least stable) hits.
    for (int i = expected.length - 1; i >= 0; i--) {
      assertEquals(Arrays.toString(results.scoreDocs), expected[i], results.scoreDocs[i].doc);
    }
  }

  // For each leaf, verify that its graph nodes are 1-1 with vectors, that the vectors are the
  // expected values,
  // and that the graph is fully connected and symmetric.
  // NOTE: when we impose max-fanout on the graph it wil no longer be symmetric, but should still
  // be fully connected. Is there any other invariant we can test? Well, we can check that max
  // fanout
  // is respected. We can test *desirable* properties of the graph like small-world (the graph
  // diameter
  // should be tightly bounded).
  private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException {
    int totalGraphDocs = 0;
    try (DirectoryReader dr = DirectoryReader.open(iw)) {
      for (LeafReaderContext ctx : dr.leaves()) {
        LeafReader reader = ctx.reader();
        VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD);
        PerFieldVectorFormat.FieldsReader perFieldReader =
            (PerFieldVectorFormat.FieldsReader) ((CodecReader) reader).getVectorReader();
        if (perFieldReader == null) {
          // Leaf has no vector data at all (e.g. every doc in it lacked a vector).
          continue;
        }
        Lucene90HnswVectorReader vectorReader =
            (Lucene90HnswVectorReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
        KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD);
        // Vectors and graph must be present or absent together.
        assertEquals((vectorValues == null), (graphValues == null));
        if (vectorValues == null) {
          continue;
        }
        // graph[ord] collects the neighbor ordinals of each node; null entry = orphan node.
        int[][] graph = new int[reader.maxDoc()][];
        boolean foundOrphan = false;
        int graphSize = 0;
        for (int i = 0; i < reader.maxDoc(); i++) {
          int nextDocWithVectors = vectorValues.advance(i);
          // System.out.println("advanced to " + nextDocWithVectors);
          // Docs skipped by advance() must be exactly those whose expected vector is null.
          while (i < nextDocWithVectors && i < reader.maxDoc()) {
            int id = Integer.parseInt(reader.document(i).get("id"));
            assertNull("document " + id + " has no vector, but was expected to", values[id]);
            ++i;
          }
          if (nextDocWithVectors == NO_MORE_DOCS) {
            break;
          }
          int id = Integer.parseInt(reader.document(i).get("id"));
          // Graph nodes are dense ordinals assigned in doc order, so node graphSize pairs with
          // the current doc-with-vector.
          graphValues.seek(graphSize);
          // documents with KnnGraphValues have the expected vectors
          float[] scratch = vectorValues.vectorValue();
          assertArrayEquals(
              "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
              values[id],
              scratch,
              0f);
          // We collect neighbors for analysis below
          List<Integer> friends = new ArrayList<>();
          int arc;
          while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
            friends.add(arc);
          }
          if (friends.size() == 0) {
            // System.out.printf("knngraph @%d is singleton (advance returns %d)\n", i,
            // nextWithNeighbors);
            foundOrphan = true;
          } else {
            // NOTE: these friends are dense ordinals, not docIds.
            int[] friendCopy = new int[friends.size()];
            for (int j = 0; j < friends.size(); j++) {
              friendCopy[j] = friends.get(j);
            }
            graph[graphSize] = friendCopy;
            // System.out.printf("knngraph @%d => %s\n", i, Arrays.toString(graph[i]));
          }
          graphSize++;
        }
        // We must have consumed every doc that has a vector.
        assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
        if (foundOrphan) {
          // An orphan (node with no neighbors) is only legal in a single-node graph.
          assertEquals("graph is not fully connected", 1, graphSize);
        } else {
          assertTrue(
              "Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
        }
        if (maxConn > graphSize) {
          // assert that the graph in each leaf is connected
          assertConnected(graph);
        } else {
          // assert that max-connections was respected
          assertMaxConn(graph, maxConn);
        }
        totalGraphDocs += graphSize;
      }
    }
    // Total node count across leaves must equal the number of non-null expected vectors.
    int expectedCount = 0;
    for (float[] friends : values) {
      if (friends != null) {
        ++expectedCount;
      }
    }
    assertEquals(expectedCount, totalGraphDocs);
  }

  /**
   * Asserts every node has at most {@code maxConn} neighbors, and that every neighbor is itself a
   * non-orphan node.
   */
  private void assertMaxConn(int[][] graph, int maxConn) {
    for (int[] ints : graph) {
      if (ints != null) {
        assert (ints.length <= maxConn);
        for (int k : ints) {
          assertNotNull(graph[k]);
        }
      }
    }
  }

  /**
   * Asserts the graph is fully connected by BFS from an arbitrary node. Assumes the non-null
   * entries occupy ordinals 0..count-1 (dense node numbering).
   */
  private void assertConnected(int[][] graph) {
    // every node in the graph is reachable from every other node
    Set<Integer> visited = new HashSet<>();
    List<Integer> queue = new LinkedList<>();
    int count = 0;
    for (int[] entry : graph) {
      if (entry != null) {
        if (queue.isEmpty()) {
          queue.add(entry[0]); // start from any node
          // System.out.println("start at " + entry[0]);
        }
        ++count;
      }
    }
    while (queue.isEmpty() == false) {
      int i = queue.remove(0);
      assertNotNull("expected neighbors of " + i, graph[i]);
      visited.add(i);
      for (int j : graph[i]) {
        if (visited.contains(j) == false) {
          // System.out.println("  ... " + j);
          queue.add(j);
        }
      }
    }
    for (int i = 0; i < count; i++) {
      assertTrue("Attempted to walk entire graph but never visited " + i, visited.contains(i));
    }
    // we visited each node exactly once
    assertEquals(
        "Attempted to walk entire graph but only visited " + visited.size(), count, visited.size());
  }

  /** Adds one document using the test's current similarity function. */
  private void add(IndexWriter iw, int id, float[] vector) throws IOException {
    add(iw, id, vector, similarityFunction);
  }

  /**
   * Adds (or updates, keyed by "id") one document with the given id and optional vector. The
   * vector field, when present, is configured with the current static maxConn and the default
   * beam width.
   */
  private void add(IndexWriter iw, int id, float[] vector, SimilarityFunction similarityFunction)
      throws IOException {
    Document doc = new Document();
    if (vector != null) {
      FieldType fieldType =
          VectorField.createHnswType(
              vector.length, similarityFunction, maxConn, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
      doc.add(new VectorField(KNN_GRAPH_FIELD, vector, fieldType));
    }
    String idString = Integer.toString(id);
    doc.add(new StringField("id", idString, Field.Store.YES));
    doc.add(new SortedDocValuesField("id", new BytesRef(idString)));
    // XSSystem.out.println("add " + idString + " " + Arrays.toString(vector));
    // updateDocument (not addDocument) so re-adding an id replaces the prior doc.
    iw.updateDocument(new Term("id", idString), doc);
  }
}