/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
/**
 * For testing indexing and search performance of a knn-graph.
 *
 * <p>java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -docs
 * .../docvectors.bin -search .../queryvectors.bin
 *
 * <p>Pass -reindex the first time to build the index. Vector files are raw little-endian float32
 * data, {@code dim} floats per vector, with no header.
 */
public class KnnGraphTester {
private static final String KNN_FIELD = "knn";
private static final String ID_FIELD = "id";
private static final VectorValues.SearchStrategy SEARCH_STRATEGY =
VectorValues.SearchStrategy.DOT_PRODUCT_HNSW;
private int numDocs;
private int dim;
private int topK;
private int warmCount;
private int numIters;
private int fanout;
private Path indexPath;
private boolean quiet;
private boolean reindex;
private boolean forceMerge;
private int reindexTimeMsec;
private int beamWidth;
private int maxConn;
@SuppressForbidden(reason = "uses Random()")
private KnnGraphTester() {
// set defaults
numDocs = 1000;
numIters = 1000;
dim = 256;
topK = 100;
warmCount = 1000;
fanout = topK;
}
public static void main(String... args) throws Exception {
new KnnGraphTester().run(args);
}
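  /**
   * Parses the command-line arguments, optionally (re)builds the index, and runs the requested
   * operation (-search, -stats or -dump).
   */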
private void run(String... args) throws Exception {
String operation = null;
Path docVectorsPath = null, queryPath = null, outputPath = null;
for (int iarg = 0; iarg < args.length; iarg++) {
String arg = args[iarg];
switch (arg) {
case "-search":
case "-check":
case "-stats":
case "-dump":
if (operation != null) {
throw new IllegalArgumentException(
"Specify only one operation, not both " + arg + " and " + operation);
}
operation = arg;
if (operation.equals("-search")) {
if (iarg == args.length - 1) {
throw new IllegalArgumentException(
"Operation " + arg + " requires a following pathname");
}
queryPath = Paths.get(args[++iarg]);
}
break;
case "-fanout":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-fanout requires a following number");
}
fanout = Integer.parseInt(args[++iarg]);
break;
case "-beamWidthIndex":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-beamWidthIndex requires a following number");
}
beamWidth = Integer.parseInt(args[++iarg]);
break;
case "-maxConn":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-maxConn requires a following number");
}
maxConn = Integer.parseInt(args[++iarg]);
break;
case "-dim":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-dim requires a following number");
}
dim = Integer.parseInt(args[++iarg]);
break;
case "-ndoc":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-ndoc requires a following number");
}
numDocs = Integer.parseInt(args[++iarg]);
break;
case "-niter":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-niter requires a following number");
}
numIters = Integer.parseInt(args[++iarg]);
break;
case "-reindex":
reindex = true;
break;
case "-topK":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-topK requires a following number");
}
topK = Integer.parseInt(args[++iarg]);
break;
case "-out":
outputPath = Paths.get(args[++iarg]);
break;
case "-warm":
warmCount = Integer.parseInt(args[++iarg]);
break;
case "-docs":
docVectorsPath = Paths.get(args[++iarg]);
break;
case "-forceMerge":
forceMerge = true;
break;
case "-quiet":
quiet = true;
break;
default:
throw new IllegalArgumentException("unknown argument " + arg);
// usage();
}
}
if (operation == null && reindex == false) {
usage();
}
indexPath = Paths.get(formatIndexPath(docVectorsPath));
if (reindex) {
if (docVectorsPath == null) {
throw new IllegalArgumentException("-docs argument is required when indexing");
}
reindexTimeMsec = createIndex(docVectorsPath, indexPath);
if (forceMerge) {
forceMerge();
}
}
if (operation != null) {
switch (operation) {
case "-search":
if (docVectorsPath == null) {
throw new IllegalArgumentException("missing -docs arg");
}
if (outputPath != null) {
testSearch(indexPath, queryPath, outputPath, null);
} else {
testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
}
break;
case "-dump":
dumpGraph(docVectorsPath);
break;
case "-stats":
printFanoutHist(indexPath);
break;
}
}
}
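  /** Derives the index directory name from the docs file name plus the maxConn/beamWidth settings. */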
private String formatIndexPath(Path docsPath) {
return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index";
}
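  /** Prints, for each leaf of the index, a fanout (neighbor-count) histogram of its HNSW graph. */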
@SuppressForbidden(reason = "Prints stuff")
private void printFanoutHist(Path indexPath) throws IOException {
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
// int[] globalHist = new int[reader.maxDoc()];
for (LeafReaderContext context : reader.leaves()) {
LeafReader leafReader = context.reader();
KnnGraphValues knnValues =
((Lucene90VectorReader) ((CodecReader) leafReader).getVectorReader())
.getGraphValues(KNN_FIELD);
System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
printGraphFanout(knnValues, leafReader.maxDoc());
}
}
}
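  /**
   * Builds an HNSW graph from the raw vector file one node at a time, dumping the neighbor lists
   * after each insertion.
   */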
private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(values.vectorValue(i));
System.out.println("\nITERATION " + i);
dumpGraph(builder.hnsw);
}
}
}
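  /** Prints every node of the graph with its neighbors sorted by score. */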
private void dumpGraph(HnswGraph hnsw) {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(i);
System.out.printf(Locale.ROOT, "%5d", i);
NeighborArray sorted = new NeighborArray(neighbors.size());
for (int j = 0; j < neighbors.size(); j++) {
int node = neighbors.node[j];
float score = neighbors.score[j];
sorted.add(node, score);
}
new NeighborArraySorter(sorted).sort(0, sorted.size());
for (int j = 0; j < sorted.size(); j++) {
System.out.printf(Locale.ROOT, " [%d, %.4f]", sorted.node[j], sorted.score[j]);
}
System.out.println();
}
}
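  /** Force-merges the existing index down to a single segment. */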
@SuppressForbidden(reason = "Prints stuff")
private void forceMerge() throws IOException {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("Force merge index in " + indexPath);
try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) {
iw.forceMerge(1);
}
}
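  /** Computes min/mean/max fanout over the graph and prints a percentile chart of the distribution. */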
@SuppressForbidden(reason = "Prints stuff")
private void printGraphFanout(KnnGraphValues knnValues, int numDocs) throws IOException {
int min = Integer.MAX_VALUE, max = 0, total = 0;
int count = 0;
int[] leafHist = new int[numDocs];
for (int node = 0; node < numDocs; node++) {
knnValues.seek(node);
int n = 0;
while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
++n;
}
++leafHist[n];
max = Math.max(max, n);
min = Math.min(min, n);
if (n > 0) {
++count;
total += n;
}
}
System.out.printf(
"Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n",
count, min, total / (float) count, max);
printHist(leafHist, max, count, 10);
}
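  /**
   * Prints an approximate percentile chart: for each 1/nbuckets fraction of the (non-isolated)
   * nodes, the fanout value by which that fraction is reached.
   */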
@SuppressForbidden(reason = "Prints stuff")
private void printHist(int[] hist, int max, int count, int nbuckets) {
System.out.print("%");
for (int i = 0; i <= nbuckets; i++) {
System.out.printf("%4d", i * 100 / nbuckets);
}
System.out.printf("\n %4d", hist[0]);
int total = 0, ibucket = 1;
for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
total += hist[i];
while (total >= count * ibucket / nbuckets) {
System.out.printf("%4d", i);
++ibucket;
}
}
System.out.println();
}
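  /**
   * Runs warmCount warmup searches followed by numIters timed searches, then either writes the
   * result doc ids to outputPath or checks recall against the exact nearest neighbors in nn.
   */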
@SuppressForbidden(reason = "Prints stuff")
private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn)
throws IOException {
TopDocs[] results = new TopDocs[numIters];
long elapsed, totalCpuTime, totalVisited = 0;
try (FileChannel q = FileChannel.open(queryPath)) {
FloatBuffer targets =
q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] target = new float[dim];
if (quiet == false) {
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
}
long start;
ThreadMXBean bean = ManagementFactory.getThreadMXBean();
long cpuTimeStartNs;
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
numDocs = reader.maxDoc();
        for (int i = 0; i < Math.min(warmCount, numIters); i++) {
// warm up
targets.get(target);
results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
}
targets.position(0);
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
targets.get(target);
results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
}
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
for (int i = 0; i < numIters; i++) {
totalVisited += results[i].totalHits.value;
for (ScoreDoc doc : results[i].scoreDocs) {
            doc.doc = Integer.parseInt(reader.document(doc.doc).get(ID_FIELD));
}
}
}
if (quiet == false) {
System.out.println(
"completed "
+ numIters
+ " searches in "
+ elapsed
+ " ms: "
+ ((1000 * numIters) / elapsed)
+ " QPS "
+ "CPU time="
+ totalCpuTime
+ "ms");
}
}
if (outputPath != null) {
ByteBuffer buf = ByteBuffer.allocate(4);
IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
try (OutputStream out = Files.newOutputStream(outputPath)) {
for (int i = 0; i < numIters; i++) {
for (ScoreDoc doc : results[i].scoreDocs) {
ibuf.position(0);
ibuf.put(doc.doc);
out.write(buf.array());
}
}
}
} else {
if (quiet == false) {
System.out.println("checking results");
}
float recall = checkResults(results, nn);
totalVisited /= numIters;
System.out.printf(
Locale.ROOT,
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
recall,
totalCpuTime / (float) numIters,
numDocs,
fanout,
maxConn,
beamWidth,
totalVisited,
reindexTimeMsec);
}
}
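  /** Searches each leaf's vector values, rebases the per-leaf doc ids, and merges into a single top-k. */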
private static TopDocs doKnnSearch(
IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
results[ctx.ord] = ctx.reader().getVectorValues(field).search(vector, k, fanout);
int docBase = ctx.docBase;
for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
scoreDoc.doc += docBase;
}
}
return TopDocs.merge(k, results);
}
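  /** Returns recall: the fraction of returned documents that appear in the exact nearest-neighbor sets. */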
private float checkResults(TopDocs[] results, int[][] nn) {
int totalMatches = 0;
int totalResults = 0;
for (int i = 0; i < results.length; i++) {
int n = results[i].scoreDocs.length;
totalResults += n;
// System.out.println(Arrays.toString(nn[i]));
// System.out.println(Arrays.toString(results[i].scoreDocs));
totalMatches += compareNN(nn[i], results[i]);
}
return totalMatches / (float) totalResults;
}
private int compareNN(int[] expected, TopDocs results) {
int matched = 0;
/*
System.out.print("expected=");
for (int j = 0; j < expected.length; j++) {
System.out.print(expected[j]);
System.out.print(", ");
}
System.out.print('\n');
System.out.println("results=");
for (int j = 0; j < results.scoreDocs.length; j++) {
System.out.print("" + results.scoreDocs[j].doc + ":" + results.scoreDocs[j].score + ", ");
}
System.out.print('\n');
*/
Set<Integer> expectedSet = new HashSet<>();
for (int i = 0; i < results.scoreDocs.length; i++) {
expectedSet.add(expected[i]);
}
for (ScoreDoc scoreDoc : results.scoreDocs) {
if (expectedSet.contains(scoreDoc.doc)) {
++matched;
}
}
return matched;
}
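  /**
   * Returns the exact nearest neighbors of the queries, reading them from a cached file keyed by
   * (numDocs, numIters, topK, dim) when present, otherwise computing and caching them.
   */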
private int[][] getNN(Path docPath, Path queryPath) throws IOException {
// look in working directory for cached nn file
String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
Path nnPath = Paths.get(nnFileName);
if (Files.exists(nnPath)) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
writeNN(nn, nnPath);
return nn;
}
}
private int[][] readNN(Path nnPath) throws IOException {
int[][] result = new int[numIters][];
try (FileChannel in = FileChannel.open(nnPath)) {
IntBuffer intBuffer =
in.map(FileChannel.MapMode.READ_ONLY, 0, numIters * topK * Integer.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asIntBuffer();
for (int i = 0; i < numIters; i++) {
result[i] = new int[topK];
intBuffer.get(result[i]);
}
}
return result;
}
private void writeNN(int[][] nn, Path nnPath) throws IOException {
if (quiet == false) {
System.out.println("writing true nearest neighbors to " + nnPath);
}
ByteBuffer tmp =
ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
try (OutputStream out = Files.newOutputStream(nnPath)) {
for (int i = 0; i < numIters; i++) {
tmp.asIntBuffer().put(nn[i]);
out.write(tmp.array());
}
}
}
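  /** Computes the exact (brute-force) top-K neighbors of each query by scanning the docs file in mapped blocks. */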
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
int[][] result = new int[numIters][];
if (quiet == false) {
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
}
try (FileChannel in = FileChannel.open(docPath);
FileChannel qIn = FileChannel.open(queryPath)) {
FloatBuffer queries =
qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] vector = new float[dim];
float[] query = new float[dim];
for (int i = 0; i < numIters; i++) {
queries.get(query);
        long totalBytes = (long) numDocs * dim * Float.BYTES;
        // map the docs file in blocks of at most ~2GB, each holding a whole number of vectors
        int blockSize =
            (int)
                Math.min(
                    totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
        long offset = 0;
        int j = 0;
        // accumulate the top-K nearest neighbors for this query across all blocks
        NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed);
        while (j < numDocs) {
          FloatBuffer vectors =
              in.map(FileChannel.MapMode.READ_ONLY, offset, Math.min(blockSize, totalBytes - offset))
                  .order(ByteOrder.LITTLE_ENDIAN)
                  .asFloatBuffer();
          offset += blockSize;
          for (; j < numDocs && vectors.hasRemaining(); j++) {
            vectors.get(vector);
            float d = SEARCH_STRATEGY.compare(query, vector);
            queue.insertWithOverflow(j, d);
          }
        }
        result[i] = new int[topK];
        for (int k = topK - 1; k >= 0; k--) {
          result[i][k] = queue.topNode();
          queue.pop();
        }
        if (quiet == false && (i + 1) % 10 == 0) {
          System.out.print(" " + (i + 1));
          System.out.flush();
        }
}
}
return result;
}
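  /** Indexes the vectors from docsPath into indexPath and returns the elapsed wall-clock time in milliseconds. */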
private int createIndex(Path docsPath, Path indexPath) throws IOException {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
iwc.setRAMBufferSizeMB(1994d);
// iwc.setMaxBufferedDocs(10000);
FieldType fieldType =
VectorField.createHnswType(
dim, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, maxConn, beamWidth);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
}
long start = System.nanoTime();
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
try (FSDirectory dir = FSDirectory.open(indexPath);
IndexWriter iw = new IndexWriter(dir, iwc)) {
int blockSize =
(int)
Math.min(totalBytes, (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES));
float[] vector = new float[dim];
try (FileChannel in = FileChannel.open(docsPath)) {
int i = 0;
while (i < numDocs) {
FloatBuffer vectors =
              in.map(FileChannel.MapMode.READ_ONLY, offset, Math.min(blockSize, totalBytes - offset))
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
for (; vectors.hasRemaining() && i < numDocs; i++) {
vectors.get(vector);
Document doc = new Document();
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
doc.add(new VectorField(KNN_FIELD, vector, fieldType));
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
}
}
if (quiet == false) {
System.out.println("Done indexing " + numDocs + " documents; now flush");
}
}
}
long elapsed = System.nanoTime() - start;
if (quiet == false) {
System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000_000 + "s");
}
return (int) (elapsed / 1_000_000);
}
private static void usage() {
    String error =
        "Usage: KnnGraphTester [-reindex] [-search {queryfile}|-stats|-dump|-check] [-docs {datafile}]"
            + " [-ndoc N] [-dim N] [-niter N] [-topK N] [-warm N] [-fanout N] [-maxConn N]"
            + " [-beamWidthIndex N] [-out {outfile}] [-forceMerge] [-quiet]";
System.err.println(error);
System.exit(1);
}
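  /** Random-access view over a binary file of little-endian float vectors, memory-mapped in a single (<2GB) buffer. */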
class BinaryFileVectors implements RandomAccessVectorValuesProducer, Closeable {
private final int size;
private final FileChannel in;
private final FloatBuffer mmap;
BinaryFileVectors(Path filePath) throws IOException {
in = FileChannel.open(filePath);
long totalBytes = (long) numDocs * dim * Float.BYTES;
if (totalBytes > Integer.MAX_VALUE) {
throw new IllegalArgumentException("input over 2GB not supported");
}
int vectorByteSize = dim * Float.BYTES;
size = (int) (totalBytes / vectorByteSize);
mmap =
in.map(FileChannel.MapMode.READ_ONLY, 0, totalBytes)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
}
@Override
public void close() throws IOException {
in.close();
}
@Override
public RandomAccessVectorValues randomAccess() {
return new Values();
}
class Values implements RandomAccessVectorValues {
float[] vector = new float[dim];
FloatBuffer source = mmap.slice();
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return dim;
}
@Override
public VectorValues.SearchStrategy searchStrategy() {
return SEARCH_STRATEGY;
}
@Override
public float[] vectorValue(int targetOrd) {
int pos = targetOrd * dim;
source.position(pos);
source.get(vector);
return vector;
}
@Override
public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException();
}
}
}
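  /** Sorts a NeighborArray's parallel node/score arrays in place by ascending score. */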
static class NeighborArraySorter extends IntroSorter {
private final int[] node;
private final float[] score;
NeighborArraySorter(NeighborArray neighbors) {
node = neighbors.node;
score = neighbors.score;
}
int pivot;
@Override
protected void swap(int i, int j) {
int tmpNode = node[i];
float tmpScore = score[i];
node[i] = node[j];
score[i] = score[j];
node[j] = tmpNode;
score[j] = tmpScore;
}
@Override
protected void setPivot(int i) {
pivot = i;
}
@Override
protected int comparePivot(int j) {
return Float.compare(score[pivot], score[j]);
}
}
}