/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.bkd;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FutureArrays;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.MSBRadixSorter;
import org.apache.lucene.util.RadixSelector;
import org.apache.lucene.util.Selector;
import org.apache.lucene.util.Sorter;
/**
* Offline Radix selector for BKD tree.
*
* @lucene.internal
*/
public final class BKDRadixSelector {
// size of the histogram
private static final int HISTOGRAM_SIZE = 256;
// size of the offline buffer: 8 KB
private static final int MAX_SIZE_OFFLINE_BUFFER = 1024 * 8;
// histogram array
private final long[] histogram;
// number of bytes to be sorted: config.bytesPerDim + (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES
private final int bytesSorted;
// maximum number of points that can be sorted on heap
private final int maxPointsSortInHeap;
// reusable buffer
private final byte[] offlineBuffer;
// holder for partition points
private final int[] partitionBucket;
// scratch array to hold temporary data
private final byte[] scratch;
// Directory to create new Offline writer
private final Directory tempDir;
// prefix for temp files
private final String tempFileNamePrefix;
// BKD tree configuration
private final BKDConfig config;
/**
* Sole constructor.
*/
public BKDRadixSelector(BKDConfig config, int maxPointsSortInHeap, Directory tempDir, String tempFileNamePrefix) {
this.config = config;
this.maxPointsSortInHeap = maxPointsSortInHeap;
this.tempDir = tempDir;
this.tempFileNamePrefix = tempFileNamePrefix;
// Selection and sorting are done in a given dimension. In case the values of the dimension are equal
// between two points, we tie-break first using the data-only dimensions and, if those are still equal,
// we tie-break on the docID. Here we account for all bytes used in the process.
this.bytesSorted = config.bytesPerDim + (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
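// For example (hypothetical values), with bytesPerDim = 4, numDims = 4 and numIndexDims = 2,
// bytesSorted = 4 + (4 - 2) * 4 + 4 = 16 bytes per point.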
final int numberOfPointsOffline = MAX_SIZE_OFFLINE_BUFFER / config.bytesPerDoc;
this.offlineBuffer = new byte[numberOfPointsOffline * config.bytesPerDoc];
this.partitionBucket = new int[bytesSorted];
this.histogram = new long[HISTOGRAM_SIZE];
this.scratch = new byte[bytesSorted];
}
/**
* Uses the provided {@code points} from the given {@code from} to the given {@code to}
* to populate the {@code partitionSlices} array holder (length > 1) with two path slices,
* so that the slice at position 0 contains {@code partitionPoint - from} points
* whose values in {@code dim} are lower than or equal to the values of the
* {@code to - partitionPoint} points in the slice at position 1.
*
* The {@code dimCommonPrefix} provides a hint for the length of the common prefix
* on {@code dim} for the points being partitioned.
*
* Returns the value of {@code dim} at the partition point.
*
* If the provided {@code points} is wrapping an {@link OfflinePointWriter}, the
* writer is destroyed in the process to save disk space.
*/
public byte[] select(PathSlice points, PathSlice[] partitionSlices, long from, long to, long partitionPoint, int dim, int dimCommonPrefix) throws IOException {
checkArgs(from, to, partitionPoint);
assert partitionSlices.length > 1 : "[partition slices] must be > 1, got " + partitionSlices.length;
// If we are on heap then we just select on heap
if (points.writer instanceof HeapPointWriter) {
byte[] partition = heapRadixSelect((HeapPointWriter) points.writer, dim, Math.toIntExact(from), Math.toIntExact(to), Math.toIntExact(partitionPoint), dimCommonPrefix);
partitionSlices[0] = new PathSlice(points.writer, from, partitionPoint - from);
partitionSlices[1] = new PathSlice(points.writer, partitionPoint, to - partitionPoint);
return partition;
}
OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer;
try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim);
PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) {
partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from);
partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint);
return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, dimCommonPrefix, dim);
}
}
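// A minimal usage sketch (illustrative only; config, maxPointsSortInHeap, tempDir, writer and
// splitDim are assumed to exist and are hypothetical here):
//
//   BKDRadixSelector selector = new BKDRadixSelector(config, maxPointsSortInHeap, tempDir, "bkd");
//   BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2];
//   BKDRadixSelector.PathSlice points = new BKDRadixSelector.PathSlice(writer, 0, writer.count());
//   long middle = writer.count() / 2;
//   byte[] splitValue = selector.select(points, slices, 0, writer.count(), middle, splitDim, 0);
//   // slices[0] holds the points with splitDim values <= splitValue, slices[1] the remaining points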
void checkArgs(long from, long to, long partitionPoint) {
if (partitionPoint < from) {
throw new IllegalArgumentException("partitionPoint must be >= from");
}
if (partitionPoint >= to) {
throw new IllegalArgumentException("partitionPoint must be < to");
}
}
private int findCommonPrefixAndHistogram(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException {
// find common prefix
int commonPrefixPosition = bytesSorted;
final int offset = dim * config.bytesPerDim;
try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
assert commonPrefixPosition > dimCommonPrefix;
reader.next();
PointValue pointValue = reader.pointValue();
BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
// copy dimension
System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + offset, scratch, 0, config.bytesPerDim);
// copy data dimensions and docID
System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + config.packedIndexBytesLength, scratch, config.bytesPerDim, (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES);
for (long i = from + 1; i < to; i++) {
reader.next();
pointValue = reader.pointValue();
if (commonPrefixPosition == dimCommonPrefix) {
histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
// we do not need to check for common prefix anymore,
// just finish the histogram and break
for (long j = i + 1; j < to; j++) {
reader.next();
pointValue = reader.pointValue();
histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
}
break;
} else {
// Check common prefix and adjust histogram
final int startIndex = (dimCommonPrefix > config.bytesPerDim) ? config.bytesPerDim : dimCommonPrefix;
final int endIndex = (commonPrefixPosition > config.bytesPerDim) ? config.bytesPerDim : commonPrefixPosition;
packedValueDocID = pointValue.packedValueDocIDBytes();
int j = FutureArrays.mismatch(scratch, startIndex, endIndex, packedValueDocID.bytes, packedValueDocID.offset + offset + startIndex, packedValueDocID.offset + offset + endIndex);
if (j == -1) {
if (commonPrefixPosition > config.bytesPerDim) {
// Tie-break on data dimensions + docID
final int startTieBreak = config.packedIndexBytesLength;
final int endTieBreak = startTieBreak + commonPrefixPosition - config.bytesPerDim;
int k = FutureArrays.mismatch(scratch, config.bytesPerDim, commonPrefixPosition,
packedValueDocID.bytes, packedValueDocID.offset + startTieBreak, packedValueDocID.offset + endTieBreak);
if (k != -1) {
commonPrefixPosition = config.bytesPerDim + k;
Arrays.fill(histogram, 0);
histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
}
}
} else {
commonPrefixPosition = dimCommonPrefix + j;
Arrays.fill(histogram, 0);
histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
}
if (commonPrefixPosition != bytesSorted) {
histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
}
}
}
}
// Build partition buckets up to commonPrefix
for (int i = 0; i < commonPrefixPosition; i++) {
partitionBucket[i] = scratch[i] & 0xff;
}
return commonPrefixPosition;
}
private int getBucket(int offset, int commonPrefixPosition, PointValue pointValue) {
int bucket;
if (commonPrefixPosition < config.bytesPerDim) {
BytesRef packedValue = pointValue.packedValue();
bucket = packedValue.bytes[packedValue.offset + offset + commonPrefixPosition] & 0xff;
} else {
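// For example (hypothetical values), with bytesPerDim = 4 and commonPrefixPosition = 6, the bucket
// byte is read 2 bytes past packedIndexBytesLength, i.e. from the data dimensions / docID region.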
BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
bucket = packedValueDocID.bytes[packedValueDocID.offset + config.packedIndexBytesLength + commonPrefixPosition - config.bytesPerDim] & 0xff;
}
return bucket;
}
private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right,
long from, long to, long partitionPoint, int iteration, int baseCommonPrefix, int dim) throws IOException {
// Find common prefix from baseCommonPrefix and build histogram
int commonPrefix = findCommonPrefixAndHistogram(points, from, to, dim, baseCommonPrefix);
// If all the values are equal we just partition the points
if (commonPrefix == bytesSorted) {
offlinePartition(points, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint);
return partitionPointFromCommonPrefix();
}
long leftCount = 0;
long rightCount = 0;
// Count left points and record the partition point
for (int i = 0; i < HISTOGRAM_SIZE; i++) {
long size = histogram[i];
if (leftCount + size > partitionPoint - from) {
partitionBucket[commonPrefix] = i;
break;
}
leftCount += size;
}
// Count right points
for (int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) {
rightCount += histogram[i];
}
long delta = histogram[partitionBucket[commonPrefix]];
assert leftCount + rightCount + delta == to - from : (leftCount + rightCount + delta) + " / " + (to - from);
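// For instance (hypothetical counts), with 100 points, partitionPoint - from = 50, leftCount = 40
// and delta = 25, the 50th point falls inside the delta bucket, so only those 25 points need
// further refinement on the next byte.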
// Special case when points are equal except for the last byte: we can just tie-break
if (commonPrefix == bytesSorted - 1) {
long tieBreakCount = partitionPoint - from - leftCount;
offlinePartition(points, left, right, null, from, to, dim, commonPrefix, tieBreakCount);
return partitionPointFromCommonPrefix();
}
// Create the delta points writer
PointWriter deltaPoints;
try (PointWriter tempDeltaPoints = getDeltaPointWriter(left, right, delta, iteration)) {
// Divide the points. This actually destroys the current writer
offlinePartition(points, left, right, tempDeltaPoints, from, to, dim, commonPrefix, 0);
deltaPoints = tempDeltaPoints;
}
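// The target rank now lies inside deltaPoints, so it is re-expressed relative to the start of
// that bucket before selecting over it on heap or recursing offline.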
long newPartitionPoint = partitionPoint - from - leftCount;
if (deltaPoints instanceof HeapPointWriter) {
return heapPartition((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix);
} else {
return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim);
}
}
private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
assert bytePosition == bytesSorted - 1 || deltaPoints != null;
int offset = dim * config.bytesPerDim;
long tiebreakCounter = 0;
try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
while (reader.next()) {
PointValue pointValue = reader.pointValue();
int bucket = getBucket(offset, bytePosition, pointValue);
if (bucket < this.partitionBucket[bytePosition]) {
// to the left side
left.append(pointValue);
} else if (bucket > this.partitionBucket[bytePosition]) {
// to the right side
right.append(pointValue);
} else {
if (bytePosition == bytesSorted - 1) {
if (tiebreakCounter < numDocsTiebreak) {
left.append(pointValue);
tiebreakCounter++;
} else {
right.append(pointValue);
}
} else {
deltaPoints.append(pointValue);
}
}
}
}
// Delete original file
points.destroy();
}
private byte[] partitionPointFromCommonPrefix() {
byte[] partition = new byte[config.bytesPerDim];
for (int i = 0; i < config.bytesPerDim; i++) {
partition[i] = (byte)partitionBucket[i];
}
return partition;
}
private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
byte[] partition = heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix);
for (int i = from; i < to; i++) {
PointValue value = points.getPackedValueSlice(i);
if (i < partitionPoint) {
left.append(value);
} else {
right.append(value);
}
}
return partition;
}
private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefixLength) {
final int dimOffset = dim * config.bytesPerDim + commonPrefixLength;
final int dimCmpBytes = config.bytesPerDim - commonPrefixLength;
final int dataOffset = config.packedIndexBytesLength - dimCmpBytes;
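// The radix key is the suffix of the dim bytes after the common prefix, followed by the data
// dimensions and the docID; byteAt() therefore reads from two different regions of the block.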
new RadixSelector(bytesSorted - commonPrefixLength) {
@Override
protected void swap(int i, int j) {
points.swap(i, j);
}
@Override
protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) {
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
}
@Override
protected Selector getFallbackSelector(int d) {
final int skippedBytes = d + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim + skippedBytes;
final int dimEnd = dim * config.bytesPerDim + config.bytesPerDim;
// the data length is composed of the data dimensions plus the docID
final int dataLength = (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
return new IntroSelector() {
@Override
protected void swap(int i, int j) {
points.swap(i, j);
}
@Override
protected void setPivot(int i) {
if (skippedBytes < config.bytesPerDim) {
System.arraycopy(points.block, i * config.bytesPerDoc + dim * config.bytesPerDim, scratch, 0, config.bytesPerDim);
}
System.arraycopy(points.block, i * config.bytesPerDoc + config.packedIndexBytesLength, scratch, config.bytesPerDim, dataLength);
}
@Override
protected int compare(int i, int j) {
if (skippedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc;
int jOffset = j * config.bytesPerDoc;
int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
if (cmp != 0) {
return cmp;
}
}
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
}
@Override
protected int comparePivot(int j) {
if (skippedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc;
int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, config.bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
if (cmp != 0) {
return cmp;
}
}
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return FutureArrays.compareUnsigned(scratch, config.bytesPerDim, config.bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
}
};
}
}.select(from, to, partitionPoint);
byte[] partition = new byte[config.bytesPerDim];
PointValue pointValue = points.getPackedValueSlice(partitionPoint);
BytesRef packedValue = pointValue.packedValue();
System.arraycopy(packedValue.bytes, packedValue.offset + dim * config.bytesPerDim, partition, 0, config.bytesPerDim);
return partition;
}
/** Sort the heap writer by the specified dim. It is used to sort the leaves of the tree */
public void heapRadixSort(final HeapPointWriter points, int from, int to, int dim, int commonPrefixLength) {
final int dimOffset = dim * config.bytesPerDim + commonPrefixLength;
final int dimCmpBytes = config.bytesPerDim - commonPrefixLength;
final int dataOffset = config.packedIndexBytesLength - dimCmpBytes;
new MSBRadixSorter(bytesSorted - commonPrefixLength) {
@Override
protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) {
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
}
@Override
protected void swap(int i, int j) {
points.swap(i, j);
}
@Override
protected Sorter getFallbackSorter(int k) {
final int skippedBytes = k + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim + skippedBytes;
final int dimEnd = dim * config.bytesPerDim + config.bytesPerDim;
// the data length is composed of the data dimensions plus the docID
final int dataLength = (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
return new IntroSorter() {
@Override
protected void swap(int i, int j) {
points.swap(i, j);
}
@Override
protected void setPivot(int i) {
if (skippedBytes < config.bytesPerDim) {
System.arraycopy(points.block, i * config.bytesPerDoc + dim * config.bytesPerDim, scratch, 0, config.bytesPerDim);
}
System.arraycopy(points.block, i * config.bytesPerDoc + config.packedIndexBytesLength, scratch, config.bytesPerDim, dataLength);
}
@Override
protected int compare(int i, int j) {
if (skippedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc;
int jOffset = j * config.bytesPerDoc;
int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
if (cmp != 0) {
return cmp;
}
}
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
}
@Override
protected int comparePivot(int j) {
if (skippedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc;
int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, config.bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
if (cmp != 0) {
return cmp;
}
}
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return FutureArrays.compareUnsigned(scratch, config.bytesPerDim, config.bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
}
};
}
}.sort(from, to);
}
private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException {
if (delta <= getMaxPointsSortInHeap(left, right)) {
return new HeapPointWriter(config, Math.toIntExact(delta));
} else {
return new OfflinePointWriter(config, tempDir, tempFileNamePrefix, "delta" + iteration, delta);
}
}
private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) {
int pointsUsed = 0;
if (left instanceof HeapPointWriter) {
pointsUsed += ((HeapPointWriter) left).size;
}
if (right instanceof HeapPointWriter) {
pointsUsed += ((HeapPointWriter) right).size;
}
assert maxPointsSortInHeap >= pointsUsed;
return maxPointsSortInHeap - pointsUsed;
}
PointWriter getPointWriter(long count, String desc) throws IOException {
// As we recurse, we hold two on-heap point writers at any point. Therefore the
// max size for these objects is half of the total points we can have on-heap.
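// For example (hypothetical limit), with maxPointsSortInHeap = 1000 a slice of up to 500 points
// stays on heap while a larger one spills to an OfflinePointWriter.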
if (count <= maxPointsSortInHeap / 2) {
int size = Math.toIntExact(count);
return new HeapPointWriter(config, size);
} else {
return new OfflinePointWriter(config, tempDir, tempFileNamePrefix, desc, count);
}
}
/** Sliced reference to points in a PointWriter. */
public static final class PathSlice {
public final PointWriter writer;
public final long start;
public final long count;
public PathSlice(PointWriter writer, long start, long count) {
this.writer = writer;
this.start = start;
this.count = count;
}
@Override
public String toString() {
return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
}
}
}