blob: fb2ecd097b9b3b5d4403c260d1aa1cf92a4a5e14 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.RamUsageEstimator;
/**
* Hierarchical NSW graph that provides efficient approximate nearest neighbor search for high dimensional vectors.
* This isn't thread-safe.
* See <a href="https://arxiv.org/abs/1603.09320">this paper</a> for details.
*/
public final class HNSWGraph implements Accountable {

  private final VectorValues.DistanceFunction distFunc;
  // layers.get(l) is the layer for level l; the last element is the top level
  private final List<Layer> layers;
  // set by finish(); once true, every mutating method throws IllegalStateException
  private boolean frozen = false;
  // RAM usage estimate; cached only after the graph is frozen (see ramBytesUsed)
  private long bytesUsed;

  public HNSWGraph(VectorValues.DistanceFunction distFunc) {
    this.distFunc = distFunc;
    this.layers = new ArrayList<>();
  }

  /**
   * Searches the nearest neighbors for a specified query at a level.
   *
   * @param query search query vector
   * @param results on entry, has enter points to this level. On exit, the nearest neighbors in this level
   * @param ef maximum number of neighbors kept in {@code results} while searching
   * @param level graph level
   * @param vectorValues vector values used to compute distances to visited nodes
   * @return number of distinct nodes visited on this level, including the initial enter points
   * @throws IllegalArgumentException if no layer exists for {@code level}
   * @throws IllegalStateException if a visited node has no vector value
   */
  int searchLayer(float[] query, FurthestNeighbors results, int ef, int level, VectorValues vectorValues) throws IOException {
    Layer layer = layerAt(level);
    // nodes still to be expanded, polled in Neighbor's natural ordering
    TreeSet<Neighbor> candidates = new TreeSet<>();
    // set of docids that have been visited by search on this layer, used to avoid backtracking
    Set<Integer> visited = new HashSet<>();
    for (Neighbor n : results) {
      candidates.add(n);
      visited.add(n.docId());
    }
    // f is the top of results; new nodes are admitted only if closer than f or while results has room.
    // NOTE(review): presumably the furthest kept neighbor, given the FurthestNeighbors type — confirm
    Neighbor f = results.top();
    while (candidates.isEmpty() == false) {
      Neighbor c = candidates.pollFirst();
      assert !c.isDeferred();
      assert !f.isDeferred();
      // results is full and this candidate is further than the current top: stop expanding
      if (c.distance() > f.distance() && results.size() >= ef) {
        break;
      }
      for (Neighbor e : layer.getFriends(c.docId())) {
        if (visited.add(e.docId()) == false) {
          continue; // already evaluated on this layer
        }
        float dist = distance(query, e.docId(), vectorValues);
        if (dist < f.distance() || results.size() < ef) {
          if (results.size() == ef) {
            // make room for the new neighbor by dropping the current top
            results.pop();
          }
          Neighbor n = new ImmutableNeighbor(e.docId(), dist);
          candidates.add(n);
          results.insertWithOverflow(n);
          f = results.top();
        }
      }
    }
    return visited.size();
  }

  /**
   * Computes the distance between {@code query} and the vector stored for {@code docId}.
   *
   * @throws IllegalStateException if {@code docId} has no vector value
   */
  private float distance(float[] query, int docId, VectorValues vectorValues) throws IOException {
    if (!vectorValues.seek(docId)) {
      throw new IllegalStateException("docId=" + docId + " has no vector value");
    }
    float[] other = vectorValues.vectorValue();
    return VectorValues.distance(query, other, distFunc);
  }

  /**
   * Drains {@code queue} into a new {@link NearestNeighbors}, skipping entries whose docId was
   * already added. Stops when as many distinct docs as the queue's initial size were collected, or
   * when the queue is empty. Note that this consumes (pops) the given queue.
   */
  static NearestNeighbors pickNearestNeighbor(Neighbors queue) {
    NearestNeighbors nearests = new NearestNeighbors(queue.size());
    Set<Integer> addedDocs = new HashSet<>();
    int ef = queue.size();
    while (addedDocs.size() < ef && queue.size() > 0) {
      Neighbor c = queue.pop();
      // Set#add returns false for a docId seen before, which filters out duplicates
      if (addedDocs.add(c.docId())) {
        nearests.add(c);
      }
    }
    return nearests;
  }

  /**
   * Creates empty layers so that {@code level} exists; existing layers are left untouched.
   *
   * @throws IllegalStateException if the graph is frozen
   * @throws IllegalArgumentException if {@code level} is negative
   */
  public void ensureLevel(int level) {
    ensureNotFrozen();
    if (level < 0) {
      throw new IllegalArgumentException("level must be a non-negative integer: " + level);
    }
    for (int l = layers.size(); l <= level; l++) {
      layers.add(new Layer(l));
    }
  }

  /** Returns the highest level index, or {@code -1} if there are no layers. */
  public int topLevel() {
    return layers.size() - 1;
  }

  /** Returns true if there are no layers or level 0 has no nodes. */
  public boolean isEmpty() {
    return layers.isEmpty() || layers.get(0).getNodes().isEmpty();
  }

  /**
   * Returns the first node of the top level.
   *
   * @throws IllegalStateException if there are no layers or the top level is empty
   */
  public int getFirstEnterPoint() {
    return topLevelNodes().get(0);
  }

  /**
   * Returns an unmodifiable snapshot of the nodes of the top level.
   *
   * @throws IllegalStateException if there are no layers or the top level is empty
   */
  public List<Integer> getEnterPoints() {
    return List.copyOf(topLevelNodes());
  }

  /** Returns the (non-empty) node list of the top layer, failing if it cannot be used as entry. */
  private List<Integer> topLevelNodes() {
    if (layers.isEmpty()) {
      throw new IllegalStateException("the graph has no layers!");
    }
    List<Integer> nodesAtMaxLevel = layers.get(layers.size() - 1).getNodes();
    if (nodesAtMaxLevel.isEmpty()) {
      throw new IllegalStateException("the max level of this graph is empty!");
    }
    return nodesAtMaxLevel;
  }

  /**
   * Returns true if the layer at {@code level} contains at least one node.
   *
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  public boolean hasNodes(int level) {
    return layerAt(level).getNodes().isEmpty() == false;
  }

  /**
   * Returns the docIds connected to {@code node} at {@code level}, sorted ascending.
   *
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  public IntsRef getFriends(int level, int node) {
    Layer layer = layerAt(level);
    int[] friends = layer.getFriends(node).stream().mapToInt(Neighbor::docId).sorted().toArray();
    return new IntsRef(friends, 0, friends.length);
  }

  /**
   * Returns true unless the layer reports the {@code Layer.NO_FRIENDS} sentinel for {@code node}.
   *
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  public boolean hasFriends(int level, int node) {
    // identity comparison against the sentinel instance is intentional
    return layerAt(level).getFriends(node) != Layer.NO_FRIENDS;
  }

  /**
   * Adds {@code node} to the layer at {@code level} if it is not already present.
   *
   * @throws IllegalStateException if the graph is frozen
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  void addNode(int level, int node) {
    ensureNotFrozen();
    layerAt(level).addNodeIfAbsent(node);
  }

  /**
   * Connects two nodes; this is supposed to be called when indexing.
   *
   * @param maxConnections if positive, the friend list of {@code node2} is shrunk to at most this size
   * @throws IllegalStateException if the graph is frozen
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  public void connectNodes(int level, int node1, int node2, float dist, int maxConnections) {
    ensureNotFrozen();
    assert level >= 0;
    assert node1 >= 0 && node2 >= 0;
    assert node1 != node2;
    Layer layer = layerAt(level);
    layer.connectNodes(node1, node2, dist);
    // ensure friends size <= maxConnections
    if (maxConnections > 0) {
      layer.shrink(node2, maxConnections);
    }
  }

  /**
   * Connects two nodes; this is supposed to be called when searching.
   *
   * @throws IllegalStateException if the graph is frozen
   * @throws IllegalArgumentException if no layer exists for {@code level}
   */
  public void connectNodes(int level, int node1, int node2) {
    ensureNotFrozen();
    assert level >= 0;
    assert node1 >= 0 && node2 >= 0;
    assert node1 != node2;
    layerAt(level).connectNodes(node1, node2);
  }

  /** Removes empty top layers and freezes the graph against further modification. */
  public void finish() {
    while (layers.isEmpty() == false && layers.get(layers.size() - 1).size() == 0) {
      // remove empty top layers
      layers.remove(layers.size() - 1);
    }
    this.frozen = true;
  }

  /**
   * Returns the estimated RAM usage of the layers. The value is cached only once the graph is
   * frozen; before that it is recomputed on each call because the graph may still grow.
   */
  @Override
  public long ramBytesUsed() {
    if (frozen == false) {
      return RamUsageEstimator.sizeOfCollection(layers);
    }
    if (bytesUsed == 0) {
      bytesUsed = RamUsageEstimator.sizeOfCollection(layers);
    }
    return bytesUsed;
  }

  /**
   * Returns the layer for {@code level}. {@link List#get} throws IndexOutOfBoundsException (it never
   * returns null) for a missing level, so the range is validated explicitly here.
   *
   * @throws IllegalArgumentException if {@code level} is negative or above the top level
   */
  private Layer layerAt(int level) {
    if (level < 0 || level >= layers.size()) {
      throw new IllegalArgumentException("layer does not exist for level: " + level);
    }
    return layers.get(level);
  }

  /** Throws if {@link #finish()} has already been called. */
  private void ensureNotFrozen() {
    if (frozen) {
      throw new IllegalStateException("graph is already freezed!");
    }
  }
}