| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.lucene.document; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.List; |
| import java.util.PriorityQueue; |
| |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.PointValues; |
| import org.apache.lucene.search.FieldDoc; |
| import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.ScoreDoc; |
| import org.apache.lucene.search.TopFieldDocs; |
| import org.apache.lucene.search.TotalHits; |
| import org.apache.lucene.util.Bits; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.bkd.BKDReader; |
| |
| /** |
| * KNN search on top of N dimensional indexed float points. |
| * |
| * @lucene.experimental |
| */ |
| public class FloatPointNearestNeighbor { |
| |
| static class Cell implements Comparable<Cell> { |
| final int readerIndex; |
| final byte[] minPacked; |
| final byte[] maxPacked; |
| final BKDReader.IndexTree index; |
| /** The closest possible distance^2 of all points in this cell */ |
| final double distanceSquared; |
| |
| Cell(BKDReader.IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSquared) { |
| this.index = index; |
| this.readerIndex = readerIndex; |
| this.minPacked = minPacked.clone(); |
| this.maxPacked = maxPacked.clone(); |
| this.distanceSquared = distanceSquared; |
| } |
| |
| public int compareTo(Cell other) { |
| return Double.compare(distanceSquared, other.distanceSquared); |
| } |
| |
| @Override |
| public String toString() { |
| return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() |
| + " isLeaf=" + index.isLeafNode() + " distanceSquared=" + distanceSquared + ")"; |
| } |
| } |
| |
| private static class NearestVisitor implements PointValues.IntersectVisitor { |
| int curDocBase; |
| Bits curLiveDocs; |
| final int topN; |
| final PriorityQueue<NearestHit> hitQueue; |
| final float[] origin; |
| final private int dims; |
| double bottomNearestDistanceSquared = Double.POSITIVE_INFINITY; |
| int bottomNearestDistanceDoc = Integer.MAX_VALUE; |
| |
| public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, float[] origin) { |
| this.hitQueue = hitQueue; |
| this.topN = topN; |
| this.origin = origin; |
| this.dims = origin.length; |
| } |
| |
| @Override |
| public void visit(int docID) { |
| throw new AssertionError(); |
| } |
| |
| @Override |
| public void visit(int docID, byte[] packedValue) { |
| // System.out.println("visit docID=" + docID + " liveDocs=" + curLiveDocs);; |
| if (curLiveDocs != null && curLiveDocs.get(docID) == false) { |
| return; |
| } |
| |
| double distanceSquared = 0.0d; |
| for (int d = 0, offset = 0 ; d < dims ; ++d, offset += Float.BYTES) { |
| double diff = (double) FloatPoint.decodeDimension(packedValue, offset) - (double) origin[d]; |
| distanceSquared += diff * diff; |
| if (distanceSquared > bottomNearestDistanceSquared) { |
| return; |
| } |
| } |
| |
| // System.out.println(" visit docID=" + docID + " distanceSquared=" + distanceSquared + " value: " + Arrays.toString(docPoint)); |
| |
| int fullDocID = curDocBase + docID; |
| |
| if (hitQueue.size() == topN) { // queue already full |
| if (distanceSquared == bottomNearestDistanceSquared && fullDocID > bottomNearestDistanceDoc) { |
| return; |
| } |
| NearestHit bottom = hitQueue.poll(); |
| // System.out.println(" bottom distanceSquared=" + bottom.distanceSquared); |
| bottom.docID = fullDocID; |
| bottom.distanceSquared = distanceSquared; |
| hitQueue.offer(bottom); |
| updateBottomNearestDistance(); |
| // System.out.println(" ** keep1, now bottom=" + bottom); |
| } else { |
| NearestHit hit = new NearestHit(); |
| hit.docID = fullDocID; |
| hit.distanceSquared = distanceSquared; |
| hitQueue.offer(hit); |
| if (hitQueue.size() == topN) { |
| updateBottomNearestDistance(); |
| } |
| // System.out.println(" ** keep2, new addition=" + hit); |
| } |
| } |
| |
| private void updateBottomNearestDistance() { |
| NearestHit newBottom = hitQueue.peek(); |
| bottomNearestDistanceSquared = newBottom.distanceSquared; |
| bottomNearestDistanceDoc = newBottom.docID; |
| } |
| |
| @Override |
| public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { |
| if (hitQueue.size() == topN && pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin) > bottomNearestDistanceSquared) { |
| return PointValues.Relation.CELL_OUTSIDE_QUERY; |
| } |
| return PointValues.Relation.CELL_CROSSES_QUERY; |
| } |
| } |
| |
| /** Holds one hit from {@link FloatPointNearestNeighbor#nearest} */ |
| static class NearestHit { |
| public int docID; |
| public double distanceSquared; |
| |
| @Override |
| public String toString() { |
| return "NearestHit(docID=" + docID + " distanceSquared=" + distanceSquared + ")"; |
| } |
| } |
| |
| private static NearestHit[] nearest(List<BKDReader> readers, List<Bits> liveDocs, List<Integer> docBases, final int topN, float[] origin) throws IOException { |
| |
| // System.out.println("NEAREST: readers=" + readers + " liveDocs=" + liveDocs + " origin: " + Arrays.toString(origin)); |
| |
| // Holds closest collected points seen so far: |
| // TODO: if we used lucene's PQ we could just updateTop instead of poll/offer: |
| final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(topN, (a, b) -> { |
| // sort by opposite distance natural order |
| int cmp = Double.compare(a.distanceSquared, b.distanceSquared); |
| return cmp != 0 ? -cmp : b.docID - a.docID; // tie-break by higher docID |
| }); |
| |
| // Holds all cells, sorted by closest to the point: |
| PriorityQueue<Cell> cellQueue = new PriorityQueue<>(); |
| |
| NearestVisitor visitor = new NearestVisitor(hitQueue, topN, origin); |
| List<BKDReader.IntersectState> states = new ArrayList<>(); |
| |
| // Add root cell for each reader into the queue: |
| int bytesPerDim = -1; |
| |
| for (int i = 0 ; i < readers.size() ; ++i) { |
| BKDReader reader = readers.get(i); |
| if (bytesPerDim == -1) { |
| bytesPerDim = reader.getBytesPerDimension(); |
| } else if (bytesPerDim != reader.getBytesPerDimension()) { |
| throw new IllegalStateException("bytesPerDim changed from " + bytesPerDim |
| + " to " + reader.getBytesPerDimension() + " across readers"); |
| } |
| byte[] minPackedValue = reader.getMinPackedValue(); |
| byte[] maxPackedValue = reader.getMaxPackedValue(); |
| BKDReader.IntersectState state = reader.getIntersectState(visitor); |
| states.add(state); |
| |
| cellQueue.offer(new Cell(state.index, i, reader.getMinPackedValue(), reader.getMaxPackedValue(), |
| pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin))); |
| } |
| |
| while (cellQueue.size() > 0) { |
| Cell cell = cellQueue.poll(); |
| // System.out.println(" visit " + cell); |
| |
| if (cell.distanceSquared > visitor.bottomNearestDistanceSquared) { |
| break; |
| } |
| |
| BKDReader reader = readers.get(cell.readerIndex); |
| if (cell.index.isLeafNode()) { |
| // System.out.println(" leaf"); |
| // Leaf block: visit all points and possibly collect them: |
| visitor.curDocBase = docBases.get(cell.readerIndex); |
| visitor.curLiveDocs = liveDocs.get(cell.readerIndex); |
| reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex)); |
| |
| //assert hitQueue.peek().distanceSquared >= cell.distanceSquared; |
| // System.out.println(" now " + hitQueue.size() + " hits"); |
| } else { |
| // System.out.println(" non-leaf"); |
| // Non-leaf block: split into two cells and put them back into the queue: |
| |
| BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue()); |
| int splitDim = cell.index.getSplitDim(); |
| |
| // we must clone the index so that we we can recurse left and right "concurrently": |
| BKDReader.IndexTree newIndex = cell.index.clone(); |
| byte[] splitPackedValue = cell.maxPacked.clone(); |
| System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim); |
| |
| cell.index.pushLeft(); |
| double distanceLeft = pointToRectangleDistanceSquared(cell.minPacked, splitPackedValue, origin); |
| if (distanceLeft <= visitor.bottomNearestDistanceSquared) { |
| cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, distanceLeft)); |
| } |
| |
| splitPackedValue = cell.minPacked.clone(); |
| System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim); |
| |
| newIndex.pushRight(); |
| double distanceRight = pointToRectangleDistanceSquared(splitPackedValue, cell.maxPacked, origin); |
| if (distanceRight <= visitor.bottomNearestDistanceSquared) { |
| cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, distanceRight)); |
| } |
| } |
| } |
| |
| NearestHit[] hits = new NearestHit[hitQueue.size()]; |
| int downTo = hitQueue.size()-1; |
| while (hitQueue.size() != 0) { |
| hits[downTo] = hitQueue.poll(); |
| downTo--; |
| } |
| //System.out.println(visitor.comp); |
| return hits; |
| } |
| |
| private static double pointToRectangleDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) { |
| double sumOfSquaredDiffs = 0.0d; |
| for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) { |
| double min = FloatPoint.decodeDimension(minPackedValue, offset); |
| if (value[i] < min) { |
| double diff = min - (double)value[i]; |
| sumOfSquaredDiffs += diff * diff; |
| continue; |
| } |
| double max = FloatPoint.decodeDimension(maxPackedValue, offset); |
| if (value[i] > max) { |
| double diff = max - (double)value[i]; |
| sumOfSquaredDiffs += diff * diff; |
| } |
| } |
| return sumOfSquaredDiffs; |
| } |
| |
| public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float... origin) throws IOException { |
| if (topN < 1) { |
| throw new IllegalArgumentException("topN must be at least 1; got " + topN); |
| } |
| if (field == null) { |
| throw new IllegalArgumentException("field must not be null"); |
| } |
| if (searcher == null) { |
| throw new IllegalArgumentException("searcher must not be null"); |
| } |
| List<BKDReader> readers = new ArrayList<>(); |
| List<Integer> docBases = new ArrayList<>(); |
| List<Bits> liveDocs = new ArrayList<>(); |
| int totalHits = 0; |
| for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) { |
| PointValues points = leaf.reader().getPointValues(field); |
| if (points != null) { |
| if (points instanceof BKDReader == false) { |
| throw new IllegalArgumentException("can only run on Lucene60PointsReader points implementation, but got " + points); |
| } |
| totalHits += points.getDocCount(); |
| readers.add((BKDReader)points); |
| docBases.add(leaf.docBase); |
| liveDocs.add(leaf.reader().getLiveDocs()); |
| } |
| } |
| |
| NearestHit[] hits = nearest(readers, liveDocs, docBases, topN, origin); |
| |
| // Convert to TopFieldDocs: |
| ScoreDoc[] scoreDocs = new ScoreDoc[hits.length]; |
| for(int i=0;i<hits.length;i++) { |
| NearestHit hit = hits[i]; |
| scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] { (float)Math.sqrt(hit.distanceSquared) }); |
| } |
| return new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, null); |
| } |
| } |