blob: da7b60e434cf334a8b43258b94d4c79e60029a35 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
/** Tests HNSW KNN graphs */
public class TestHnsw extends LuceneTestCase {
// test writing out and reading in a graph gives the same graph
public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1;
int nDoc = random().nextInt(100) + 1;
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
long seed = random().nextLong();
HnswGraphBuilder.randSeed = seed;
HnswGraphBuilder builder = new HnswGraphBuilder(vectors);
HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed;
try (Directory dir = newDirectory()) {
int nVec = 0, indexedDoc = 0;
// Don't merge randomly, create a single segment because we rely on the docid ordering for
// this test
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(Codec.forName("Lucene90"));
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
while (v2.nextDoc() != NO_MORE_DOCS) {
while (indexedDoc < v2.docID()) {
// increment docId in the index by adding empty documents
iw.addDocument(new Document());
indexedDoc++;
}
Document doc = new Document();
doc.add(new VectorField("field", v2.vectorValue(), v2.similarityFunction));
doc.add(new StoredField("id", v2.docID()));
iw.addDocument(doc);
nVec++;
indexedDoc++;
}
}
try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) {
VectorValues values = ctx.reader().getVectorValues("field");
assertEquals(vectors.similarityFunction, values.similarityFunction());
assertEquals(dim, values.dimension());
assertEquals(nVec, values.size());
assertEquals(indexedDoc, ctx.reader().maxDoc());
assertEquals(indexedDoc, ctx.reader().numDocs());
assertVectorsEqual(v3, values);
KnnGraphValues graphValues =
((Lucene90HnswVectorReader)
((PerFieldVectorFormat.FieldsReader)
((CodecReader) ctx.reader()).getVectorReader())
.getFieldReader("field"))
.getGraphValues("field");
assertGraphEqual(hnsw, graphValues, nVec);
}
}
}
}
// Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
public void testAknnDiverse() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
HnswGraph.search(new float[] {1, 0}, 10, 5, vectors.randomAccess(), hnsw, random());
int sum = 0;
for (int node : nn.nodes()) {
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) =
// 45
assertTrue("sum(result docs)=" + sum, sum < 75);
for (int i = 0; i < nDoc; i++) {
NeighborArray neighbors = hnsw.getNeighbors(i);
int[] nodes = neighbors.node;
for (int j = 0; j < neighbors.size(); j++) {
// all neighbors should be valid node ids.
assertTrue(nodes[j] < nDoc);
}
}
}
public void testBoundsCheckerMax() {
BoundsChecker max = BoundsChecker.create(false);
float f = random().nextFloat() - 0.5f;
// any float > -MAX_VALUE is in bounds
assertFalse(max.check(f));
// f is now the bound (minus some delta)
max.update(f);
assertFalse(max.check(f)); // f is not out of bounds
assertFalse(max.check(f + 1)); // anything greater than f is in bounds
assertTrue(max.check(f - 1e-5f)); // delta is zero initially
}
public void testBoundsCheckerMin() {
BoundsChecker min = BoundsChecker.create(true);
float f = random().nextFloat() - 0.5f;
// any float < MAX_VALUE is in bounds
assertFalse(min.check(f));
// f is now the bound (minus some delta)
min.update(f);
assertFalse(min.check(f)); // f is not out of bounds
assertFalse(min.check(f - 1)); // anything less than f is in bounds
assertTrue(min.check(f + 1e-5f)); // delta is zero initially
}
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
expectThrows(
IllegalArgumentException.class,
() -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
expectThrows(
IllegalArgumentException.class,
() -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
}
public void testDiversity() throws IOException {
// Some carefully checked test cases with simple 2d vectors on the unit circle:
MockVectorValues vectors =
new MockVectorValues(
VectorValues.SimilarityFunction.DOT_PRODUCT,
new float[][] {
unitVector2d(0.5),
unitVector2d(0.75),
unitVector2d(0.2),
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
});
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(vectors.vectorValue(1));
builder.addGraphNode(vectors.vectorValue(2));
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
assertNeighbors(builder.hnsw, 0, 1, 2);
assertNeighbors(builder.hnsw, 1, 0);
assertNeighbors(builder.hnsw, 2, 0);
builder.addGraphNode(vectors.vectorValue(3));
assertNeighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
assertNeighbors(builder.hnsw, 1, 0, 3);
assertNeighbors(builder.hnsw, 2, 0);
assertNeighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor
builder.addGraphNode(vectors.vectorValue(4));
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertNeighbors(builder.hnsw, 0, 1, 2);
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
// replace it
assertNeighbors(builder.hnsw, 1, 0, 4);
assertNeighbors(builder.hnsw, 2, 0);
// 1 survives the diversity check
assertNeighbors(builder.hnsw, 3, 1, 4);
assertNeighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(vectors.vectorValue(5));
assertNeighbors(builder.hnsw, 0, 1, 2);
assertNeighbors(builder.hnsw, 1, 0, 5);
assertNeighbors(builder.hnsw, 2, 0);
// even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
assertNeighbors(builder.hnsw, 3, 1, 4);
assertNeighbors(builder.hnsw, 4, 3, 5);
assertNeighbors(builder.hnsw, 5, 1, 4);
}
private void assertNeighbors(HnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(node);
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
Arrays.sort(actual);
assertArrayEquals(
"expected: " + Arrays.toString(expected) + " actual: " + Arrays.toString(actual),
expected,
actual);
}
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
int topK = 5;
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
NeighborQueue expected = new NeighborQueue(topK, vectors.similarityFunction.reversed);
for (int j = 0; j < size; j++) {
float[] v = vectors.vectorValue(j);
if (v != null) {
expected.insertWithOverflow(
j, vectors.similarityFunction.compare(query, vectors.vectorValue(j)));
}
}
assertEquals(topK, actual.size());
totalMatches += computeOverlap(actual.nodes(), expected.nodes());
}
double overlap = totalMatches / (double) (100 * topK);
System.out.println("overlap=" + overlap + " totalMatches=" + totalMatches);
assertTrue("overlap=" + overlap, overlap > 0.9);
}
private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
int overlap = 0;
for (int i = 0, j = 0; i < a.length && j < b.length; ) {
if (a[i] == b[j]) {
++overlap;
++i;
++j;
} else if (a[i] > b[j]) {
++j;
} else {
++i;
}
}
return overlap;
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
private final int size;
private final float[] value;
int doc = -1;
CircularVectorValues(int size) {
this.size = size;
value = new float[2];
}
public CircularVectorValues copy() {
return new CircularVectorValues(size);
}
@Override
public SimilarityFunction similarityFunction() {
return SimilarityFunction.DOT_PRODUCT;
}
@Override
public int dimension() {
return 2;
}
@Override
public int size() {
return size;
}
@Override
public float[] vectorValue() {
return vectorValue(doc);
}
@Override
public RandomAccessVectorValues randomAccess() {
return new CircularVectorValues(size);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
return advance(doc + 1);
}
@Override
public int advance(int target) {
if (target >= 0 && target < size) {
doc = target;
} else {
doc = NO_MORE_DOCS;
}
return doc;
}
@Override
public long cost() {
return size;
}
@Override
public float[] vectorValue(int ord) {
return unitVector2d(ord / (double) size, value);
}
@Override
public BytesRef binaryValue(int ord) {
return null;
}
}
private static float[] unitVector2d(double piRadians) {
return unitVector2d(piRadians, new float[2]);
}
private static float[] unitVector2d(double piRadians, float[] value) {
value[0] = (float) Math.cos(Math.PI * piRadians);
value[1] = (float) Math.sin(Math.PI * piRadians);
return value;
}
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
for (int node = 0; node < size; node++) {
g.seek(node);
h.seek(node);
assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
}
}
private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
Set<Integer> neighbors = new HashSet<>();
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
neighbors.add(n);
}
return neighbors;
}
private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
int uDoc, vDoc;
while (true) {
uDoc = u.nextDoc();
vDoc = v.nextDoc();
assertEquals(uDoc, vDoc);
if (uDoc == NO_MORE_DOCS) {
break;
}
assertArrayEquals(
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
}
}
/** Produces random vectors and caches them for random-access. */
static class RandomVectorValues extends MockVectorValues {
RandomVectorValues(int size, int dimension, Random random) {
super(
SimilarityFunction.values()[random.nextInt(SimilarityFunction.values().length - 1) + 1],
createRandomVectors(size, dimension, random));
}
RandomVectorValues(RandomVectorValues other) {
super(other.similarityFunction, other.values);
}
@Override
public RandomVectorValues copy() {
return new RandomVectorValues(this);
}
private static float[][] createRandomVectors(int size, int dimension, Random random) {
float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector(random, dimension);
}
return vectors;
}
}
private static float[] randomVector(Random random, int dim) {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat();
}
VectorUtil.l2normalize(vec);
return vec;
}
}