| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.lucene.util.hnsw; |
| |
| import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; |
| |
| import java.io.Closeable; |
| import java.io.IOException; |
| import java.io.OutputStream; |
| import java.lang.management.ManagementFactory; |
| import java.lang.management.ThreadMXBean; |
| import java.nio.ByteBuffer; |
| import java.nio.ByteOrder; |
| import java.nio.FloatBuffer; |
| import java.nio.IntBuffer; |
| import java.nio.channels.FileChannel; |
| import java.nio.file.Files; |
| import java.nio.file.Path; |
| import java.nio.file.Paths; |
| import java.nio.file.attribute.FileTime; |
| import java.util.Arrays; |
| import java.util.HashSet; |
| import java.util.Locale; |
| import java.util.Objects; |
| import java.util.Set; |
| import org.apache.lucene.codecs.KnnVectorsFormat; |
| import org.apache.lucene.codecs.KnnVectorsReader; |
| import org.apache.lucene.codecs.lucene94.Lucene94Codec; |
| import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat; |
| import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsReader; |
| import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; |
| import org.apache.lucene.document.Document; |
| import org.apache.lucene.document.FieldType; |
| import org.apache.lucene.document.KnnVectorField; |
| import org.apache.lucene.document.StoredField; |
| import org.apache.lucene.index.CodecReader; |
| import org.apache.lucene.index.DirectoryReader; |
| import org.apache.lucene.index.IndexWriter; |
| import org.apache.lucene.index.IndexWriterConfig; |
| import org.apache.lucene.index.LeafReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.RandomAccessVectorValues; |
| import org.apache.lucene.index.RandomAccessVectorValuesProducer; |
| import org.apache.lucene.index.VectorEncoding; |
| import org.apache.lucene.index.VectorSimilarityFunction; |
| import org.apache.lucene.search.ConstantScoreScorer; |
| import org.apache.lucene.search.ConstantScoreWeight; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.KnnVectorQuery; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.search.QueryVisitor; |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.ScoreMode; |
| import org.apache.lucene.search.Scorer; |
| import org.apache.lucene.search.TopDocs; |
| import org.apache.lucene.search.Weight; |
| import org.apache.lucene.store.Directory; |
| import org.apache.lucene.store.FSDirectory; |
| import org.apache.lucene.util.BitSetIterator; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.FixedBitSet; |
| import org.apache.lucene.util.IntroSorter; |
| import org.apache.lucene.util.PrintStreamInfoStream; |
| import org.apache.lucene.util.SuppressForbidden; |
| |
| /** |
| * For testing indexing and search performance of a knn-graph |
| * |
| * <p>java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search |
| * .../vectors.bin |
| */ |
public class KnnGraphTester {

  // name of the indexed vector field
  private static final String KNN_FIELD = "knn";
  // name of the stored integer id field
  private static final String ID_FIELD = "id";

  // number of document vectors to index / search over
  private int numDocs;
  // dimension of every document and query vector
  private int dim;
  // number of nearest neighbors requested per query
  private int topK;
  // number of query iterations to run
  private int numIters;
  // extra candidates requested beyond topK during graph search
  private int fanout;
  // directory holding (or to hold) the index; derived from the docs file name
  private Path indexPath;
  // suppress progress/diagnostic output when true
  private boolean quiet;
  // rebuild the index before running the operation when true
  private boolean reindex;
  // force-merge the index down to one segment after indexing when true
  private boolean forceMerge;
  // wall-clock indexing time, recorded for the results summary line
  private int reindexTimeMsec;
  // HNSW construction beam width
  private int beamWidth;
  // HNSW max connections per node
  private int maxConn;
  private VectorSimilarityFunction similarityFunction;
  private VectorEncoding vectorEncoding;
  // when selectivity < 1, the random subset of docs allowed to match
  private FixedBitSet matchDocs;
  // fraction of documents that pass the filter (1 = no filtering)
  private float selectivity;
  // filter during graph search (pre-filter) rather than after (post-filter)
  private boolean prefilter;
  /** Private: instances are created only by {@link #main} with these defaults. */
  private KnnGraphTester() {
    // set defaults; most are overridden by command-line flags parsed in run()
    numDocs = 1000;
    numIters = 1000;
    dim = 256;
    topK = 100;
    fanout = topK;
    similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
    vectorEncoding = VectorEncoding.FLOAT32;
    selectivity = 1f;
    prefilter = false;
  }
| |
| public static void main(String... args) throws Exception { |
| new KnnGraphTester().run(args); |
| } |
| |
| private void run(String... args) throws Exception { |
| String operation = null; |
| Path docVectorsPath = null, queryPath = null, outputPath = null; |
| for (int iarg = 0; iarg < args.length; iarg++) { |
| String arg = args[iarg]; |
| switch (arg) { |
| case "-search": |
| case "-check": |
| case "-stats": |
| case "-dump": |
| if (operation != null) { |
| throw new IllegalArgumentException( |
| "Specify only one operation, not both " + arg + " and " + operation); |
| } |
| operation = arg; |
| if (operation.equals("-search")) { |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException( |
| "Operation " + arg + " requires a following pathname"); |
| } |
| queryPath = Paths.get(args[++iarg]); |
| } |
| break; |
| case "-fanout": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-fanout requires a following number"); |
| } |
| fanout = Integer.parseInt(args[++iarg]); |
| break; |
| case "-beamWidthIndex": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-beamWidthIndex requires a following number"); |
| } |
| beamWidth = Integer.parseInt(args[++iarg]); |
| break; |
| case "-maxConn": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-maxConn requires a following number"); |
| } |
| maxConn = Integer.parseInt(args[++iarg]); |
| break; |
| case "-dim": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-dim requires a following number"); |
| } |
| dim = Integer.parseInt(args[++iarg]); |
| break; |
| case "-ndoc": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-ndoc requires a following number"); |
| } |
| numDocs = Integer.parseInt(args[++iarg]); |
| break; |
| case "-niter": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-niter requires a following number"); |
| } |
| numIters = Integer.parseInt(args[++iarg]); |
| break; |
| case "-reindex": |
| reindex = true; |
| break; |
| case "-topK": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-topK requires a following number"); |
| } |
| topK = Integer.parseInt(args[++iarg]); |
| break; |
| case "-out": |
| outputPath = Paths.get(args[++iarg]); |
| break; |
| case "-docs": |
| docVectorsPath = Paths.get(args[++iarg]); |
| break; |
| case "-encoding": |
| String encoding = args[++iarg]; |
| switch (encoding) { |
| case "byte": |
| vectorEncoding = VectorEncoding.BYTE; |
| break; |
| case "float32": |
| vectorEncoding = VectorEncoding.FLOAT32; |
| break; |
| default: |
| throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only"); |
| } |
| break; |
| case "-metric": |
| String metric = args[++iarg]; |
| switch (metric) { |
| case "euclidean": |
| similarityFunction = VectorSimilarityFunction.EUCLIDEAN; |
| break; |
| case "angular": |
| similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; |
| break; |
| default: |
| throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only"); |
| } |
| break; |
| case "-forceMerge": |
| forceMerge = true; |
| break; |
| case "-prefilter": |
| prefilter = true; |
| break; |
| case "-filterSelectivity": |
| if (iarg == args.length - 1) { |
| throw new IllegalArgumentException("-filterSelectivity requires a following float"); |
| } |
| selectivity = Float.parseFloat(args[++iarg]); |
| if (selectivity <= 0 || selectivity >= 1) { |
| throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1"); |
| } |
| break; |
| case "-quiet": |
| quiet = true; |
| break; |
| default: |
| throw new IllegalArgumentException("unknown argument " + arg); |
| // usage(); |
| } |
| } |
| if (operation == null && reindex == false) { |
| usage(); |
| } |
| if (prefilter && selectivity == 1f) { |
| throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1"); |
| } |
| indexPath = Paths.get(formatIndexPath(docVectorsPath)); |
| if (reindex) { |
| if (docVectorsPath == null) { |
| throw new IllegalArgumentException("-docs argument is required when indexing"); |
| } |
| reindexTimeMsec = createIndex(docVectorsPath, indexPath); |
| if (forceMerge) { |
| forceMerge(); |
| } |
| } |
| if (operation != null) { |
| switch (operation) { |
| case "-search": |
| if (docVectorsPath == null) { |
| throw new IllegalArgumentException("missing -docs arg"); |
| } |
| if (selectivity < 1) { |
| matchDocs = generateRandomBitSet(numDocs, selectivity); |
| } |
| if (outputPath != null) { |
| testSearch(indexPath, queryPath, outputPath, null); |
| } else { |
| testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath)); |
| } |
| break; |
| case "-dump": |
| dumpGraph(docVectorsPath); |
| break; |
| case "-stats": |
| printFanoutHist(indexPath); |
| break; |
| } |
| } |
| } |
| |
| private String formatIndexPath(Path docsPath) { |
| return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index"; |
| } |
| |
  /**
   * For each leaf of the index at {@code indexPath}, prints the document count and a fanout
   * histogram of the level-0 HNSW graph for {@link #KNN_FIELD}.
   */
  @SuppressForbidden(reason = "Prints stuff")
  private void printFanoutHist(Path indexPath) throws IOException {
    try (Directory dir = FSDirectory.open(indexPath);
        DirectoryReader reader = DirectoryReader.open(dir)) {
      for (LeafReaderContext context : reader.leaves()) {
        LeafReader leafReader = context.reader();
        // unwrap the per-field vectors reader to reach the codec-specific HNSW graph
        KnnVectorsReader vectorsReader =
            ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader())
                .getFieldReader(KNN_FIELD);
        HnswGraph knnValues = ((Lucene94HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD);
        System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc());
        printGraphFanout(knnValues, leafReader.maxDoc());
      }
    }
  }
| |
  /**
   * Incrementally builds an HNSW graph from the vectors in {@code docsPath}, dumping the entire
   * graph to stdout after every node insertion (a debugging aid for graph construction).
   */
  @SuppressWarnings("unchecked")
  private void dumpGraph(Path docsPath) throws IOException {
    try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
      RandomAccessVectorValues values = vectors.randomAccess();
      HnswGraphBuilder<float[]> builder =
          (HnswGraphBuilder<float[]>)
              HnswGraphBuilder.create(
                  vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
      // start at node 1 — presumably node 0 is already present as the graph's entry point;
      // confirm against HnswGraphBuilder.create semantics
      for (int i = 1; i < numDocs; i++) {
        builder.addGraphNode(i, values);
        System.out.println("\nITERATION " + i);
        dumpGraph(builder.hnsw);
      }
    }
  }
| |
| private void dumpGraph(OnHeapHnswGraph hnsw) { |
| for (int i = 0; i < hnsw.size(); i++) { |
| NeighborArray neighbors = hnsw.getNeighbors(0, i); |
| System.out.printf(Locale.ROOT, "%5d", i); |
| NeighborArray sorted = new NeighborArray(neighbors.size(), true); |
| for (int j = 0; j < neighbors.size(); j++) { |
| int node = neighbors.node[j]; |
| float score = neighbors.score[j]; |
| sorted.add(node, score); |
| } |
| new NeighborArraySorter(sorted).sort(0, sorted.size()); |
| for (int j = 0; j < sorted.size(); j++) { |
| System.out.printf(Locale.ROOT, " [%d, %.4f]", sorted.node[j], sorted.score[j]); |
| } |
| System.out.println(); |
| } |
| } |
| |
| @SuppressForbidden(reason = "Prints stuff") |
| private void forceMerge() throws IOException { |
| IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND); |
| iwc.setInfoStream(new PrintStreamInfoStream(System.out)); |
| System.out.println("Force merge index in " + indexPath); |
| try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) { |
| iw.forceMerge(1); |
| } |
| } |
| |
  /**
   * Prints min/mean/max level-0 fanout over all nodes of {@code knnValues}, then the cumulative
   * histogram via {@link #printHist}. Nodes with zero neighbors are excluded from the mean and
   * from {@code count}.
   */
  @SuppressForbidden(reason = "Prints stuff")
  private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException {
    int min = Integer.MAX_VALUE, max = 0, total = 0;
    int count = 0;
    // leafHist[n] = number of nodes with exactly n neighbors; fanout is bounded well below
    // numDocs, so this allocation is generous but safe
    int[] leafHist = new int[numDocs];
    for (int node = 0; node < numDocs; node++) {
      knnValues.seek(0, node);
      int n = 0;
      while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
        ++n;
      }
      ++leafHist[n];
      max = Math.max(max, n);
      min = Math.min(min, n);
      if (n > 0) {
        ++count;
        total += n;
      }
    }
    // NOTE(review): if every node has zero neighbors, count == 0 and the mean prints as NaN
    System.out.printf(
        "Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n",
        count, min, total / (float) count, max);
    printHist(leafHist, max, count, 10);
  }
| |
  /**
   * Prints an approximate percentile table of the fanout histogram: a header row of percentages,
   * then for each of {@code nbuckets} percentile steps, the smallest fanout value at which the
   * running node total reaches that fraction of {@code count}.
   */
  @SuppressForbidden(reason = "Prints stuff")
  private void printHist(int[] hist, int max, int count, int nbuckets) {
    System.out.print("%");
    for (int i = 0; i <= nbuckets; i++) {
      System.out.printf("%4d", i * 100 / nbuckets);
    }
    // hist[0] (nodes with no neighbors) is shown separately; it is excluded from count
    System.out.printf("\n %4d", hist[0]);
    int total = 0, ibucket = 1;
    for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
      total += hist[i];
      // a single fanout value may span several percentile buckets, hence the inner while
      while (total >= count * ibucket / nbuckets) {
        System.out.printf("%4d", i);
        ++ibucket;
      }
    }
    System.out.println();
  }
| |
  /**
   * Runs {@code numIters} KNN searches against the index (after one unmeasured warm-up pass over
   * the same queries), recording wall-clock and CPU time. If {@code outputPath} is non-null the
   * result doc ids are written there as little-endian ints and no recall check is performed;
   * otherwise recall is computed against the expected neighbors {@code nn} and a tab-separated
   * summary line is printed.
   *
   * @param indexPath directory containing the index to search
   * @param queryPath file of query vectors, decoded with {@link VectorReader}
   * @param outputPath destination for raw results, or null to check recall instead
   * @param nn expected nearest neighbors per query; used only when outputPath is null
   */
  @SuppressForbidden(reason = "Prints stuff")
  private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn)
      throws IOException {
    TopDocs[] results = new TopDocs[numIters];
    long elapsed, totalCpuTime, totalVisited = 0;
    try (FileChannel input = FileChannel.open(queryPath)) {
      VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding, numIters);
      if (quiet == false) {
        System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
      }
      long start;
      // per-thread CPU time lets us report CPU cost separately from wall-clock latency
      ThreadMXBean bean = ManagementFactory.getThreadMXBean();
      long cpuTimeStartNs;
      try (Directory dir = FSDirectory.open(indexPath);
          DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        numDocs = reader.maxDoc();
        Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
        for (int i = 0; i < numIters; i++) {
          // warm up
          float[] target = targetReader.next();
          if (prefilter) {
            doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
          } else {
            // post-filter mode over-fetches by 1/selectivity so ~topK hits survive filtering
            doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
          }
        }
        targetReader.reset();
        start = System.nanoTime();
        cpuTimeStartNs = bean.getCurrentThreadCpuTime();
        for (int i = 0; i < numIters; i++) {
          float[] target = targetReader.next();
          if (prefilter) {
            results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
          } else {
            results[i] =
                doKnnVectorQuery(
                    searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);

            if (matchDocs != null) {
              // post-filter: drop hits that are outside the allowed doc set
              results[i].scoreDocs =
                  Arrays.stream(results[i].scoreDocs)
                      .filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
                      .toArray(ScoreDoc[]::new);
            }
          }
        }
        totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
        elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
        // translate index-internal doc ids to the stored "id" field values so results are
        // comparable to the brute-force neighbor ordinals
        for (int i = 0; i < numIters; i++) {
          totalVisited += results[i].totalHits.value;
          for (ScoreDoc doc : results[i].scoreDocs) {
            if (doc.doc != NO_MORE_DOCS) {
              // there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
              // in some degenerate case (like input query has NaN in it?) that causes no results to
              // be returned from HNSW search?
              doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
            } else {
              System.out.println("NO_MORE_DOCS!");
            }
          }
        }
      }
      if (quiet == false) {
        System.out.println(
            "completed "
                + numIters
                + " searches in "
                + elapsed
                + " ms: "
                + ((1000 * numIters) / elapsed)
                + " QPS "
                + "CPU time="
                + totalCpuTime
                + "ms");
      }
    }
    if (outputPath != null) {
      // write each result doc id as a 4-byte little-endian int
      ByteBuffer buf = ByteBuffer.allocate(4);
      IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
      try (OutputStream out = Files.newOutputStream(outputPath)) {
        for (int i = 0; i < numIters; i++) {
          for (ScoreDoc doc : results[i].scoreDocs) {
            ibuf.position(0);
            ibuf.put(doc.doc);
            out.write(buf.array());
          }
        }
      }
    } else {
      if (quiet == false) {
        System.out.println("checking results");
      }
      float recall = checkResults(results, nn);
      totalVisited /= numIters;
      // summary columns: recall, CPU ms/query, ndoc, fanout, maxConn, beamWidth, visited,
      // reindex ms, selectivity, filter mode
      System.out.printf(
          Locale.ROOT,
          "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
          recall,
          totalCpuTime / (float) numIters,
          numDocs,
          fanout,
          maxConn,
          beamWidth,
          totalVisited,
          reindexTimeMsec,
          selectivity,
          prefilter ? "pre-filter" : "post-filter");
    }
  }
| |
| private abstract static class VectorReader { |
| final float[] target; |
| final ByteBuffer bytes; |
| |
| static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n) |
| throws IOException { |
| int bufferSize = n * dim * vectorEncoding.byteSize; |
| return switch (vectorEncoding) { |
| case BYTE -> new VectorReaderByte(input, dim, bufferSize); |
| case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize); |
| }; |
| } |
| |
| VectorReader(FileChannel input, int dim, int bufferSize) throws IOException { |
| bytes = |
| input.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize).order(ByteOrder.LITTLE_ENDIAN); |
| target = new float[dim]; |
| } |
| |
| void reset() { |
| bytes.position(0); |
| } |
| |
| abstract float[] next(); |
| } |
| |
  /** Decodes raw little-endian float32 vectors through a FloatBuffer view of the mapped bytes. */
  private static class VectorReaderFloat32 extends VectorReader {
    private final FloatBuffer floats;

    VectorReaderFloat32(FileChannel input, int dim, int bufferSize) throws IOException {
      super(input, dim, bufferSize);
      floats = bytes.asFloatBuffer();
    }

    @Override
    void reset() {
      // the FloatBuffer view keeps its own position, so rewind it alongside the byte buffer
      super.reset();
      floats.position(0);
    }

    @Override
    float[] next() {
      floats.get(target);
      return target;
    }
  }
| |
| private static class VectorReaderByte extends VectorReader { |
| private byte[] scratch; |
| private BytesRef bytesRef; |
| |
| VectorReaderByte(FileChannel input, int dim, int bufferSize) throws IOException { |
| super(input, dim, bufferSize); |
| scratch = new byte[dim]; |
| bytesRef = new BytesRef(scratch); |
| } |
| |
| @Override |
| float[] next() { |
| bytes.get(scratch); |
| for (int i = 0; i < scratch.length; i++) { |
| target[i] = scratch[i]; |
| } |
| return target; |
| } |
| |
| BytesRef nextBytes() { |
| bytes.get(scratch); |
| return bytesRef; |
| } |
| } |
| |
| private static TopDocs doKnnVectorQuery( |
| IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter) |
| throws IOException { |
| return searcher.search(new KnnVectorQuery(field, vector, k + fanout, filter), k); |
| } |
| |
| private float checkResults(TopDocs[] results, int[][] nn) { |
| int totalMatches = 0; |
| int totalResults = results.length * topK; |
| for (int i = 0; i < results.length; i++) { |
| // System.out.println(Arrays.toString(nn[i])); |
| // System.out.println(Arrays.toString(results[i].scoreDocs)); |
| totalMatches += compareNN(nn[i], results[i]); |
| } |
| return totalMatches / (float) totalResults; |
| } |
| |
| private int compareNN(int[] expected, TopDocs results) { |
| int matched = 0; |
| /* |
| System.out.print("expected="); |
| for (int j = 0; j < expected.length; j++) { |
| System.out.print(expected[j]); |
| System.out.print(", "); |
| } |
| System.out.print('\n'); |
| System.out.println("results="); |
| for (int j = 0; j < results.scoreDocs.length; j++) { |
| System.out.print("" + results.scoreDocs[j].doc + ":" + results.scoreDocs[j].score + ", "); |
| } |
| System.out.print('\n'); |
| */ |
| Set<Integer> expectedSet = new HashSet<>(); |
| for (int i = 0; i < topK; i++) { |
| expectedSet.add(expected[i]); |
| } |
| for (ScoreDoc scoreDoc : results.scoreDocs) { |
| if (expectedSet.contains(scoreDoc.doc)) { |
| ++matched; |
| } |
| } |
| return matched; |
| } |
| |
  /**
   * Returns the true nearest neighbors for every query. Uses a cached {@code nn-<hash>.bin} file
   * in the working directory when it exists, is newer than both inputs, and no filtering is in
   * effect; otherwise computes them by brute force (caching the result when selectivity == 1,
   * since filtered neighbor sets depend on the random filter).
   */
  private int[][] getNN(Path docPath, Path queryPath) throws IOException {
    // look in working directory for cached nn file
    String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
    String nnFileName = "nn-" + hash + ".bin";
    Path nnPath = Paths.get(nnFileName);
    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
      return readNN(nnPath);
    } else {
      // TODO: enable computing NN from high precision vectors when
      // checking low-precision recall
      int[][] nn = computeNN(docPath, queryPath, vectorEncoding);
      if (selectivity == 1f) {
        writeNN(nn, nnPath);
      }
      return nn;
    }
  }
| |
| private boolean isNewer(Path path, Path... others) throws IOException { |
| FileTime modified = Files.getLastModifiedTime(path); |
| for (Path other : others) { |
| if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| private int[][] readNN(Path nnPath) throws IOException { |
| int[][] result = new int[numIters][]; |
| try (FileChannel in = FileChannel.open(nnPath)) { |
| IntBuffer intBuffer = |
| in.map(FileChannel.MapMode.READ_ONLY, 0, numIters * topK * Integer.BYTES) |
| .order(ByteOrder.LITTLE_ENDIAN) |
| .asIntBuffer(); |
| for (int i = 0; i < numIters; i++) { |
| result[i] = new int[topK]; |
| intBuffer.get(result[i]); |
| } |
| } |
| return result; |
| } |
| |
| private void writeNN(int[][] nn, Path nnPath) throws IOException { |
| if (quiet == false) { |
| System.out.println("writing true nearest neighbors to " + nnPath); |
| } |
| ByteBuffer tmp = |
| ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); |
| try (OutputStream out = Files.newOutputStream(nnPath)) { |
| for (int i = 0; i < numIters; i++) { |
| tmp.asIntBuffer().put(nn[i]); |
| out.write(tmp.array()); |
| } |
| } |
| } |
| |
| @SuppressForbidden(reason = "Uses random()") |
| private static FixedBitSet generateRandomBitSet(int size, float selectivity) { |
| FixedBitSet bitSet = new FixedBitSet(size); |
| for (int i = 0; i < size; i++) { |
| if (Math.random() < selectivity) { |
| bitSet.set(i); |
| } else { |
| bitSet.clear(i); |
| } |
| } |
| return bitSet; |
| } |
| |
  /**
   * Computes, by exhaustive comparison against every document vector, the topK nearest neighbors
   * of each query vector, honoring {@link #matchDocs} when set. Returns one int[topK] of doc
   * ordinals per query.
   */
  private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
      throws IOException {
    int[][] result = new int[numIters][];
    if (quiet == false) {
      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
    }
    try (FileChannel in = FileChannel.open(docPath);
        FileChannel qIn = FileChannel.open(queryPath)) {
      VectorReader docReader = VectorReader.create(in, dim, encoding, numDocs);
      VectorReader queryReader = VectorReader.create(qIn, dim, encoding, numIters);
      for (int i = 0; i < numIters; i++) {
        float[] query = queryReader.next();
        NeighborQueue queue = new NeighborQueue(topK, false);
        for (int j = 0; j < numDocs; j++) {
          float[] doc = docReader.next();
          float d = similarityFunction.compare(query, doc);
          // filtered docs are skipped so the "true" neighbors match what a filtered search sees
          if (matchDocs == null || matchDocs.get(j)) {
            queue.insertWithOverflow(j, d);
          }
        }
        docReader.reset();
        result[i] = new int[topK];
        // drain back-to-front: the queue presumably pops worst-first (min-heap), so filling
        // from the end yields best-first order — confirm against NeighborQueue semantics
        for (int k = topK - 1; k >= 0; k--) {
          result[i][k] = queue.topNode();
          queue.pop();
        }
        if (quiet == false && (i + 1) % 10 == 0) {
          System.out.print(" " + (i + 1));
          System.out.flush();
        }
      }
    }
    return result;
  }
| |
  /**
   * Builds a fresh index at {@code indexPath} from the vectors in {@code docsPath}, one document
   * per vector plus a stored integer id field, using the HNSW format configured with this
   * tester's maxConn/beamWidth. Returns the elapsed wall-clock time in milliseconds.
   */
  private int createIndex(Path docsPath, Path indexPath) throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    iwc.setCodec(
        new Lucene94Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return new Lucene94HnswVectorsFormat(maxConn, beamWidth);
          }
        });
    // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
    // very large RAM buffer — presumably to minimize flushes (each flush builds a graph); confirm
    iwc.setRAMBufferSizeMB(1994d);
    iwc.setUseCompoundFile(false);
    // iwc.setMaxBufferedDocs(10000);

    FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
    if (quiet == false) {
      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
      System.out.println("creating index in " + indexPath);
    }
    long start = System.nanoTime();
    try (FSDirectory dir = FSDirectory.open(indexPath);
        IndexWriter iw = new IndexWriter(dir, iwc)) {
      try (FileChannel in = FileChannel.open(docsPath)) {
        VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding, numDocs);
        for (int i = 0; i < numDocs; i++) {
          Document doc = new Document();
          // byte encoding adds the raw bytes; float32 adds the decoded float vector
          switch (vectorEncoding) {
            case BYTE -> doc.add(
                new KnnVectorField(
                    KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
            case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
          }
          doc.add(new StoredField(ID_FIELD, i));
          iw.addDocument(doc);
        }
        if (quiet == false) {
          System.out.println("Done indexing " + numDocs + " documents; now flush");
        }
      }
    }
    long elapsed = System.nanoTime() - start;
    if (quiet == false) {
      System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000_000 + "s");
    }
    return (int) (elapsed / 1_000_000);
  }
| |
| private static void usage() { |
| String error = |
| "Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]"; |
| System.err.println(error); |
| System.exit(1); |
| } |
| |
  /**
   * Random access over float32 vectors stored contiguously, little-endian, in a binary file that
   * is memory-mapped in full (inputs over 2GB are rejected). Uses the enclosing tester's
   * numDocs/dim settings.
   */
  class BinaryFileVectors implements RandomAccessVectorValuesProducer, Closeable {

    private final int size; // number of vectors in the file (== numDocs by construction)
    private final FileChannel in;
    private final FloatBuffer mmap;

    BinaryFileVectors(Path filePath) throws IOException {
      in = FileChannel.open(filePath);
      long totalBytes = (long) numDocs * dim * Float.BYTES;
      if (totalBytes > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("input over 2GB not supported");
      }
      int vectorByteSize = dim * Float.BYTES;
      size = (int) (totalBytes / vectorByteSize);
      mmap =
          in.map(FileChannel.MapMode.READ_ONLY, 0, totalBytes)
              .order(ByteOrder.LITTLE_ENDIAN)
              .asFloatBuffer();
    }

    @Override
    public void close() throws IOException {
      in.close();
    }

    @Override
    public RandomAccessVectorValues randomAccess() {
      // each accessor gets an independent buffer slice so positions do not interfere
      return new Values();
    }

    /** One independent random-access view over the shared mapped buffer. */
    class Values implements RandomAccessVectorValues {

      float[] vector = new float[dim]; // shared scratch returned by vectorValue
      FloatBuffer source = mmap.slice();

      @Override
      public int size() {
        return size;
      }

      @Override
      public int dimension() {
        return dim;
      }

      @Override
      public float[] vectorValue(int targetOrd) {
        // vectors are fixed-width, so ordinal -> float offset is a simple multiply
        int pos = targetOrd * dim;
        source.position(pos);
        source.get(vector);
        return vector;
      }

      @Override
      public BytesRef binaryValue(int targetOrd) {
        throw new UnsupportedOperationException();
      }
    }
  }
| |
  /**
   * Sorts a {@link NeighborArray}'s parallel node/score arrays in place by score via
   * {@link IntroSorter}; swaps keep node[i] and score[i] paired. Ordering direction follows
   * IntroSorter's comparePivot contract with {@link Float#compare} on scores.
   */
  static class NeighborArraySorter extends IntroSorter {
    private final int[] node;
    private final float[] score;

    NeighborArraySorter(NeighborArray neighbors) {
      // sorts the NeighborArray's own backing arrays in place (no copies)
      node = neighbors.node;
      score = neighbors.score;
    }

    // index of the current pivot element, captured by setPivot for comparePivot
    int pivot;

    @Override
    protected void swap(int i, int j) {
      int tmpNode = node[i];
      float tmpScore = score[i];
      node[i] = node[j];
      score[i] = score[j];
      node[j] = tmpNode;
      score[j] = tmpScore;
    }

    @Override
    protected void setPivot(int i) {
      pivot = i;
    }

    @Override
    protected int comparePivot(int j) {
      return Float.compare(score[pivot], score[j]);
    }
  }
| |
  /**
   * A constant-scoring query matching exactly the documents whose bits are set in a fixed bit
   * set; used as the pre-filter in filtered KNN runs. Deliberately not cacheable (isCacheable
   * returns false).
   *
   * <p>NOTE(review): the same bit set is iterated for every leaf context, i.e. bits are treated
   * as global doc ids — this looks like it assumes a single-segment index; confirm before
   * reusing with multi-segment indexes.
   */
  private static class BitSetQuery extends Query {

    private final FixedBitSet docs;
    private final int cardinality; // precomputed for BitSetIterator's cost estimate

    BitSetQuery(FixedBitSet docs) {
      this.docs = docs;
      this.cardinality = docs.cardinality();
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
        throws IOException {
      return new ConstantScoreWeight(this, boost) {
        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException {
          return new ConstantScoreScorer(
              this, score(), scoreMode, new BitSetIterator(docs, cardinality));
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
          return false;
        }
      };
    }

    @Override
    public void visit(QueryVisitor visitor) {}

    @Override
    public String toString(String field) {
      return "BitSetQuery";
    }

    @Override
    public boolean equals(Object other) {
      return sameClassAs(other) && docs.equals(((BitSetQuery) other).docs);
    }

    @Override
    public int hashCode() {
      return 31 * classHash() + docs.hashCode();
    }
  }
| } |