/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.bkd;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.IntFunction;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.MutablePointTree;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.TrackingDirectoryWrapper;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.bkd.BKDUtil.ByteArrayPredicate;
// TODO
// - allow variable length byte[] (across docs and dims), but this is quite a bit more hairy
// - we could also index "auto-prefix terms" here, and use better compression, and maybe only
// use it for the "fully contained" case so we'd only index docIDs
// - the index could be efficiently encoded as an FST, so we don't have wasteful
// (monotonic) long[] leafBlockFPs; or we could use MonotonicLongValues ... but then
// the index is already plenty small: 60M OSM points --> 1.1 MB with 128 points
// per leaf, and you can reduce that by putting more points per leaf
// - we could use threads while building; the higher nodes are very parallelizable
/**
* Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller and
* smaller N-dim rectangles (cells) until the number of points in a given rectangle is &lt;= <code>
* config.maxPointsInLeafNode</code>. The tree is partially balanced, which means the leaf nodes
will have the requested <code>config.maxPointsInLeafNode</code>, except for one that may have fewer.
* Leaf nodes may straddle the two bottom levels of the binary tree. Values that fall exactly on a
* cell boundary may be in either cell.
*
* <p>The number of dimensions can be 1 to 8, but every byte[] value is fixed length.
*
* <p>This consumes heap during writing: it allocates a <code>long[numLeaves]</code>, a <code>
* byte[numLeaves*(1+config.bytesPerDim)]</code> and then uses up to the specified {@code
* maxMBSortInHeap} heap space for writing.
*
* <p><b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>config.maxPointsInLeafNode
* </code> / config.bytesPerDim total points.
*
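* <p>A minimal usage sketch (illustrative only; {@code dir}, {@code metaOut}, {@code indexOut},
* {@code dataOut}, {@code maxDoc}, {@code numPoints} and {@code docID} are assumed to be supplied
* by the surrounding points format):
*
* <pre>{@code
* BKDConfig config = new BKDConfig(1, 1, Long.BYTES, BKDConfig.DEFAULT_MAX_POINTS_IN_LEAF_NODE);
* try (BKDWriter writer = new BKDWriter(maxDoc, dir, "tmp", config,
*     BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, numPoints)) {
*   byte[] packed = new byte[Long.BYTES];
*   NumericUtils.longToSortableBytes(42L, packed, 0);
*   writer.add(packed, docID);
*   Runnable writeIndex = writer.finish(metaOut, indexOut, dataOut);
*   if (writeIndex != null) {
*     writeIndex.run(); // writes the tree index once all points are flushed
*   }
* }
* }</pre>
*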
* @lucene.experimental
*/
public class BKDWriter implements Closeable {
public static final String CODEC_NAME = "BKD";
public static final int VERSION_START = 4; // version used by Lucene 7.0
// public static final int VERSION_CURRENT = VERSION_START;
public static final int VERSION_LEAF_STORES_BOUNDS = 5;
public static final int VERSION_SELECTIVE_INDEXING = 6;
public static final int VERSION_LOW_CARDINALITY_LEAVES = 7;
public static final int VERSION_META_FILE = 9;
public static final int VERSION_CURRENT = VERSION_META_FILE;
/** Number of splits before we compute the exact bounding box of an inner node. */
private static final int SPLITS_BEFORE_EXACT_BOUNDS = 4;
/** Default maximum heap to use, before spilling to (slower) disk */
public static final float DEFAULT_MAX_MB_SORT_IN_HEAP = 16.0f;
/** BKD tree configuration */
protected final BKDConfig config;
private final ByteArrayComparator comparator;
private final ByteArrayPredicate equalsPredicate;
private final ByteArrayComparator commonPrefixComparator;
final TrackingDirectoryWrapper tempDir;
final String tempFileNamePrefix;
final double maxMBSortInHeap;
final byte[] scratchDiff;
final byte[] scratch1;
final byte[] scratch2;
final BytesRef scratchBytesRef1 = new BytesRef();
final BytesRef scratchBytesRef2 = new BytesRef();
final int[] commonPrefixLengths;
protected final FixedBitSet docsSeen;
private PointWriter pointWriter;
private boolean finished;
private IndexOutput tempInput;
private final int maxPointsSortInHeap;
/** Minimum per-dim values, packed */
protected final byte[] minPackedValue;
/** Maximum per-dim values, packed */
protected final byte[] maxPackedValue;
protected long pointCount;
/** An upper bound on how many points the caller will add (includes deletions) */
private final long totalPointCount;
private final int maxDoc;
private final DocIdsWriter docIdsWriter;
public BKDWriter(
int maxDoc,
Directory tempDir,
String tempFileNamePrefix,
BKDConfig config,
double maxMBSortInHeap,
long totalPointCount) {
verifyParams(maxMBSortInHeap, totalPointCount);
// We use tracking dir to deal with removing files on exception, so each place that
// creates temp files doesn't need crazy try/finally/success logic:
this.tempDir = new TrackingDirectoryWrapper(tempDir);
this.tempFileNamePrefix = tempFileNamePrefix;
this.maxMBSortInHeap = maxMBSortInHeap;
this.totalPointCount = totalPointCount;
this.maxDoc = maxDoc;
this.config = config;
this.comparator = ArrayUtil.getUnsignedComparator(config.bytesPerDim);
this.equalsPredicate = BKDUtil.getEqualsPredicate(config.bytesPerDim);
this.commonPrefixComparator = BKDUtil.getPrefixLengthComparator(config.bytesPerDim);
docsSeen = new FixedBitSet(maxDoc);
scratchDiff = new byte[config.bytesPerDim];
scratch1 = new byte[config.packedBytesLength];
scratch2 = new byte[config.packedBytesLength];
commonPrefixLengths = new int[config.numDims];
minPackedValue = new byte[config.packedIndexBytesLength];
maxPackedValue = new byte[config.packedIndexBytesLength];
// Maximum number of points we hold in memory at any time
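// (illustrative: with the 16 MB default and bytesPerDoc=12, i.e. one 8-byte dim plus a
// 4-byte docID, this allows roughly 1.4M points to be sorted in heap)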
maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (config.bytesPerDoc));
docIdsWriter = new DocIdsWriter(config.maxPointsInLeafNode);
// Finally, we must be able to hold at least the leaf node in heap during build:
if (maxPointsSortInHeap < config.maxPointsInLeafNode) {
throw new IllegalArgumentException(
"maxMBSortInHeap="
+ maxMBSortInHeap
+ " only allows for maxPointsSortInHeap="
+ maxPointsSortInHeap
+ ", but this is less than maxPointsInLeafNode="
+ config.maxPointsInLeafNode
+ "; "
+ "either increase maxMBSortInHeap or decrease maxPointsInLeafNode");
}
}
private static void verifyParams(double maxMBSortInHeap, long totalPointCount) {
if (maxMBSortInHeap < 0.0) {
throw new IllegalArgumentException(
"maxMBSortInHeap must be >= 0.0 (got: " + maxMBSortInHeap + ")");
}
if (totalPointCount < 0) {
throw new IllegalArgumentException(
"totalPointCount must be >=0 (got: " + totalPointCount + ")");
}
}
private void initPointWriter() throws IOException {
assert pointWriter == null : "Point writer is already initialized";
// Total point count is an estimate; the final point count must be equal to or lower than that
// number.
if (totalPointCount > maxPointsSortInHeap) {
pointWriter = new OfflinePointWriter(config, tempDir, tempFileNamePrefix, "spill", 0);
tempInput = ((OfflinePointWriter) pointWriter).out;
} else {
pointWriter = new HeapPointWriter(config, Math.toIntExact(totalPointCount));
}
}
public void add(byte[] packedValue, int docID) throws IOException {
if (packedValue.length != config.packedBytesLength) {
throw new IllegalArgumentException(
"packedValue should be length="
+ config.packedBytesLength
+ " (got: "
+ packedValue.length
+ ")");
}
if (pointCount >= totalPointCount) {
throw new IllegalStateException(
"totalPointCount="
+ totalPointCount
+ " was passed when we were created, but we just hit "
+ (pointCount + 1)
+ " values");
}
if (pointCount == 0) {
initPointWriter();
System.arraycopy(packedValue, 0, minPackedValue, 0, config.packedIndexBytesLength);
System.arraycopy(packedValue, 0, maxPackedValue, 0, config.packedIndexBytesLength);
} else {
for (int dim = 0; dim < config.numIndexDims; dim++) {
int offset = dim * config.bytesPerDim;
if (comparator.compare(packedValue, offset, minPackedValue, offset) < 0) {
System.arraycopy(packedValue, offset, minPackedValue, offset, config.bytesPerDim);
} else if (comparator.compare(packedValue, offset, maxPackedValue, offset) > 0) {
System.arraycopy(packedValue, offset, maxPackedValue, offset, config.bytesPerDim);
}
}
}
pointWriter.append(packedValue, docID);
pointCount++;
docsSeen.set(docID);
}
private static class MergeReader {
private final PointValues.PointTree pointTree;
private final int packedBytesLength;
private final MergeState.DocMap docMap;
private final MergeIntersectsVisitor mergeIntersectsVisitor;
/** Which doc in this block we are up to */
private int docBlockUpto;
/** Current doc ID */
public int docID;
/** Current packed value */
public final byte[] packedValue;
public MergeReader(PointValues pointValues, MergeState.DocMap docMap) throws IOException {
this.packedBytesLength = pointValues.getBytesPerDimension() * pointValues.getNumDimensions();
this.pointTree = pointValues.getPointTree();
this.mergeIntersectsVisitor = new MergeIntersectsVisitor(packedBytesLength);
// move to first child of the tree and collect docs
while (pointTree.moveToChild()) {}
pointTree.visitDocValues(mergeIntersectsVisitor);
this.docMap = docMap;
this.packedValue = new byte[packedBytesLength];
}
public boolean next() throws IOException {
// System.out.println("MR.next this=" + this);
while (true) {
if (docBlockUpto == mergeIntersectsVisitor.docsInBlock) {
if (collectNextLeaf() == false) {
assert mergeIntersectsVisitor.docsInBlock == 0;
return false;
}
assert mergeIntersectsVisitor.docsInBlock > 0;
docBlockUpto = 0;
}
final int index = docBlockUpto++;
int oldDocID = mergeIntersectsVisitor.docIDs[index];
int mappedDocID;
if (docMap == null) {
mappedDocID = oldDocID;
} else {
mappedDocID = docMap.get(oldDocID);
}
if (mappedDocID != -1) {
// Not deleted!
docID = mappedDocID;
System.arraycopy(
mergeIntersectsVisitor.packedValues,
index * packedBytesLength,
packedValue,
0,
packedBytesLength);
return true;
}
}
}
private boolean collectNextLeaf() throws IOException {
assert pointTree.moveToChild() == false;
mergeIntersectsVisitor.reset();
do {
if (pointTree.moveToSibling()) {
// move to first child of this node and collect docs
while (pointTree.moveToChild()) {}
pointTree.visitDocValues(mergeIntersectsVisitor);
return true;
}
} while (pointTree.moveToParent());
return false;
}
}
private static class MergeIntersectsVisitor implements IntersectVisitor {
int docsInBlock = 0;
byte[] packedValues;
int[] docIDs;
private final int packedBytesLength;
MergeIntersectsVisitor(int packedBytesLength) {
this.docIDs = new int[0];
this.packedValues = new byte[0];
this.packedBytesLength = packedBytesLength;
}
void reset() {
docsInBlock = 0;
}
@Override
public void grow(int count) {
assert docsInBlock == 0;
if (docIDs.length < count) {
docIDs = ArrayUtil.grow(docIDs, count);
int packedValuesSize = Math.toIntExact(docIDs.length * (long) packedBytesLength);
if (packedValuesSize > ArrayUtil.MAX_ARRAY_LENGTH) {
throw new IllegalStateException(
"array length must be <= to "
+ ArrayUtil.MAX_ARRAY_LENGTH
+ " but was: "
+ packedValuesSize);
}
packedValues = ArrayUtil.growExact(packedValues, packedValuesSize);
}
}
@Override
public void visit(int docID) {
throw new UnsupportedOperationException();
}
@Override
public void visit(int docID, byte[] packedValue) {
System.arraycopy(
packedValue, 0, packedValues, docsInBlock * packedBytesLength, packedBytesLength);
docIDs[docsInBlock++] = docID;
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_CROSSES_QUERY;
}
}
private static class BKDMergeQueue extends PriorityQueue<MergeReader> {
private final int bytesPerDim;
public BKDMergeQueue(int bytesPerDim, int maxSize) {
super(maxSize);
this.bytesPerDim = bytesPerDim;
}
@Override
public boolean lessThan(MergeReader a, MergeReader b) {
assert a != b;
int cmp =
Arrays.compareUnsigned(a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
if (cmp < 0) {
return true;
} else if (cmp > 0) {
return false;
}
// Tie break by sorting smaller docIDs earlier:
return a.docID < b.docID;
}
}
/** flat representation of a kd-tree */
private interface BKDTreeLeafNodes {
/** number of leaf nodes */
int numLeaves();
/**
* pointer to the leaf node previously written. Leaves are ordered from left to right, so the
* leaf at {@code index} 0 is the leftmost leaf and the leaf at {@code numLeaves() - 1} is the
* rightmost leaf
*/
long getLeafLP(int index);
/**
* split value between two leaves. The split value at position n corresponds to the leaves at
* positions (n - 1) and n.
*/
BytesRef getSplitValue(int index);
/**
* split dimension between two leaves. The split dimension at position n corresponds to the
* leaves at positions (n - 1) and n.
*/
int getSplitDimension(int index);
}
/**
* Write a field from a {@link MutablePointTree}. This way of writing points is faster than
* regular writes with {@link BKDWriter#add} since there is opportunity for reordering points
* before writing them to disk. This method does not use transient disk in order to reorder
* points.
*/
public Runnable writeField(
IndexOutput metaOut,
IndexOutput indexOut,
IndexOutput dataOut,
String fieldName,
MutablePointTree reader)
throws IOException {
if (config.numDims == 1) {
return writeField1Dim(metaOut, indexOut, dataOut, fieldName, reader);
} else {
return writeFieldNDims(metaOut, indexOut, dataOut, fieldName, reader);
}
}
private void computePackedValueBounds(
MutablePointTree values,
int from,
int to,
byte[] minPackedValue,
byte[] maxPackedValue,
BytesRef scratch) {
if (from == to) {
return;
}
values.getValue(from, scratch);
System.arraycopy(
scratch.bytes, scratch.offset, minPackedValue, 0, config.packedIndexBytesLength);
System.arraycopy(
scratch.bytes, scratch.offset, maxPackedValue, 0, config.packedIndexBytesLength);
for (int i = from + 1; i < to; ++i) {
values.getValue(i, scratch);
for (int dim = 0; dim < config.numIndexDims; dim++) {
final int startOffset = dim * config.bytesPerDim;
final int endOffset = startOffset + config.bytesPerDim;
if (Arrays.compareUnsigned(
scratch.bytes,
scratch.offset + startOffset,
scratch.offset + endOffset,
minPackedValue,
startOffset,
endOffset)
< 0) {
System.arraycopy(
scratch.bytes,
scratch.offset + startOffset,
minPackedValue,
startOffset,
config.bytesPerDim);
} else if (Arrays.compareUnsigned(
scratch.bytes,
scratch.offset + startOffset,
scratch.offset + endOffset,
maxPackedValue,
startOffset,
endOffset)
> 0) {
System.arraycopy(
scratch.bytes,
scratch.offset + startOffset,
maxPackedValue,
startOffset,
config.bytesPerDim);
}
}
}
}
/* In the 2+D case, we recursively pick the split dimension, compute the
* median value and partition other values around it. */
private Runnable writeFieldNDims(
IndexOutput metaOut,
IndexOutput indexOut,
IndexOutput dataOut,
String fieldName,
MutablePointTree values)
throws IOException {
if (pointCount != 0) {
throw new IllegalStateException("cannot mix add and writeField");
}
// Catch user silliness:
if (finished == true) {
throw new IllegalStateException("already finished");
}
// Mark that we already finished:
finished = true;
pointCount = values.size();
final int numLeaves =
Math.toIntExact((pointCount + config.maxPointsInLeafNode - 1) / config.maxPointsInLeafNode);
final int numSplits = numLeaves - 1;
checkMaxLeafNodeCount(numLeaves);
final byte[] splitPackedValues = new byte[numSplits * config.bytesPerDim];
final byte[] splitDimensionValues = new byte[numSplits];
final long[] leafBlockFPs = new long[numLeaves];
// compute the min/max for this slice
computePackedValueBounds(
values, 0, Math.toIntExact(pointCount), minPackedValue, maxPackedValue, scratchBytesRef1);
for (int i = 0; i < Math.toIntExact(pointCount); ++i) {
docsSeen.set(values.getDocID(i));
}
final long dataStartFP = dataOut.getFilePointer();
final int[] parentSplits = new int[config.numIndexDims];
build(
0,
numLeaves,
values,
0,
Math.toIntExact(pointCount),
dataOut,
minPackedValue.clone(),
maxPackedValue.clone(),
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
new int[config.maxPointsInLeafNode]);
assert Arrays.equals(parentSplits, new int[config.numIndexDims]);
scratchBytesRef1.length = config.bytesPerDim;
scratchBytesRef1.bytes = splitPackedValues;
BKDTreeLeafNodes leafNodes =
new BKDTreeLeafNodes() {
@Override
public long getLeafLP(int index) {
return leafBlockFPs[index];
}
@Override
public BytesRef getSplitValue(int index) {
scratchBytesRef1.offset = index * config.bytesPerDim;
return scratchBytesRef1;
}
@Override
public int getSplitDimension(int index) {
return splitDimensionValues[index] & 0xff;
}
@Override
public int numLeaves() {
return leafBlockFPs.length;
}
};
return () -> {
try {
writeIndex(metaOut, indexOut, config.maxPointsInLeafNode, leafNodes, dataStartFP);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
/* In the 1D case, we can simply sort points in ascending order and use the
* same writing logic as we use at merge time. */
private Runnable writeField1Dim(
IndexOutput metaOut,
IndexOutput indexOut,
IndexOutput dataOut,
String fieldName,
MutablePointTree reader)
throws IOException {
MutablePointTreeReaderUtils.sort(config, maxDoc, reader, 0, Math.toIntExact(reader.size()));
final OneDimensionBKDWriter oneDimWriter =
new OneDimensionBKDWriter(metaOut, indexOut, dataOut);
reader.visitDocValues(
new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {
oneDimWriter.add(packedValue, docID);
}
@Override
public void visit(int docID) {
throw new IllegalStateException();
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_CROSSES_QUERY;
}
});
return oneDimWriter.finish();
}
/**
* More efficient bulk-add for incoming {@link PointValues}s. This does a merge sort of the
* already sorted values and currently only works when numDims==1. This returns {@code null} if
* all documents containing dimensional values were deleted.
*/
public Runnable merge(
IndexOutput metaOut,
IndexOutput indexOut,
IndexOutput dataOut,
List<MergeState.DocMap> docMaps,
List<PointValues> readers)
throws IOException {
assert docMaps == null || readers.size() == docMaps.size();
BKDMergeQueue queue = new BKDMergeQueue(config.bytesPerDim, readers.size());
for (int i = 0; i < readers.size(); i++) {
PointValues pointValues = readers.get(i);
assert pointValues.getNumDimensions() == config.numDims
&& pointValues.getBytesPerDimension() == config.bytesPerDim
&& pointValues.getNumIndexDimensions() == config.numIndexDims;
MergeState.DocMap docMap;
if (docMaps == null) {
docMap = null;
} else {
docMap = docMaps.get(i);
}
MergeReader reader = new MergeReader(pointValues, docMap);
if (reader.next()) {
queue.add(reader);
}
}
OneDimensionBKDWriter oneDimWriter = new OneDimensionBKDWriter(metaOut, indexOut, dataOut);
while (queue.size() != 0) {
MergeReader reader = queue.top();
// System.out.println("iter reader=" + reader);
oneDimWriter.add(reader.packedValue, reader.docID);
if (reader.next()) {
queue.updateTop();
} else {
// This segment was exhausted
queue.pop();
}
}
return oneDimWriter.finish();
}
private class OneDimensionBKDWriter {
final IndexOutput metaOut, indexOut, dataOut;
final long dataStartFP;
final List<Long> leafBlockFPs = new ArrayList<>();
final List<byte[]> leafBlockStartValues = new ArrayList<>();
final byte[] leafValues = new byte[config.maxPointsInLeafNode * config.packedBytesLength];
final int[] leafDocs = new int[config.maxPointsInLeafNode];
private long valueCount;
private int leafCount;
private int leafCardinality;
OneDimensionBKDWriter(IndexOutput metaOut, IndexOutput indexOut, IndexOutput dataOut) {
if (config.numIndexDims != 1) {
throw new UnsupportedOperationException(
"config.numIndexDims must be 1 but got " + config.numIndexDims);
}
if (pointCount != 0) {
throw new IllegalStateException("cannot mix add and merge");
}
// Catch user silliness:
if (finished == true) {
throw new IllegalStateException("already finished");
}
// Mark that we already finished:
finished = true;
this.metaOut = metaOut;
this.indexOut = indexOut;
this.dataOut = dataOut;
this.dataStartFP = dataOut.getFilePointer();
lastPackedValue = new byte[config.packedBytesLength];
}
// for asserts
final byte[] lastPackedValue;
private int lastDocID;
void add(byte[] packedValue, int docID) throws IOException {
assert valueInOrder(
config, valueCount + leafCount, 0, lastPackedValue, packedValue, 0, docID, lastDocID);
if (leafCount == 0
|| equalsPredicate.test(leafValues, (leafCount - 1) * config.bytesPerDim, packedValue, 0)
== false) {
leafCardinality++;
}
System.arraycopy(
packedValue,
0,
leafValues,
leafCount * config.packedBytesLength,
config.packedBytesLength);
leafDocs[leafCount] = docID;
docsSeen.set(docID);
leafCount++;
if (valueCount + leafCount > totalPointCount) {
throw new IllegalStateException(
"totalPointCount="
+ totalPointCount
+ " was passed when we were created, but we just hit "
+ (valueCount + leafCount)
+ " values");
}
if (leafCount == config.maxPointsInLeafNode) {
// We write a block once we hit exactly the max count ... this is different from
// when we write N > 1 dimensional points where we write between max/2 and max per leaf
// block
writeLeafBlock(leafCardinality);
leafCardinality = 0;
leafCount = 0;
}
assert (lastDocID = docID) >= 0; // only assign when asserts are enabled
}
public Runnable finish() throws IOException {
if (leafCount > 0) {
writeLeafBlock(leafCardinality);
leafCardinality = 0;
leafCount = 0;
}
if (valueCount == 0) {
return null;
}
pointCount = valueCount;
scratchBytesRef1.length = config.bytesPerDim;
scratchBytesRef1.offset = 0;
assert leafBlockStartValues.size() + 1 == leafBlockFPs.size();
BKDTreeLeafNodes leafNodes =
new BKDTreeLeafNodes() {
@Override
public long getLeafLP(int index) {
return leafBlockFPs.get(index);
}
@Override
public BytesRef getSplitValue(int index) {
scratchBytesRef1.bytes = leafBlockStartValues.get(index);
return scratchBytesRef1;
}
@Override
public int getSplitDimension(int index) {
return 0;
}
@Override
public int numLeaves() {
return leafBlockFPs.size();
}
};
return () -> {
try {
writeIndex(metaOut, indexOut, config.maxPointsInLeafNode, leafNodes, dataStartFP);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
private void writeLeafBlock(int leafCardinality) throws IOException {
assert leafCount != 0;
if (valueCount == 0) {
System.arraycopy(leafValues, 0, minPackedValue, 0, config.packedIndexBytesLength);
}
System.arraycopy(
leafValues,
(leafCount - 1) * config.packedBytesLength,
maxPackedValue,
0,
config.packedIndexBytesLength);
valueCount += leafCount;
if (leafBlockFPs.size() > 0) {
// Save the first (minimum) value in each leaf block except the first, to build the split
// value index in the end:
leafBlockStartValues.add(ArrayUtil.copyOfSubArray(leafValues, 0, config.packedBytesLength));
}
leafBlockFPs.add(dataOut.getFilePointer());
checkMaxLeafNodeCount(leafBlockFPs.size());
// Find per-dim common prefix:
commonPrefixLengths[0] =
commonPrefixComparator.compare(
leafValues, 0, leafValues, (leafCount - 1) * config.packedBytesLength);
writeLeafBlockDocs(dataOut, leafDocs, 0, leafCount);
writeCommonPrefixes(dataOut, commonPrefixLengths, leafValues);
scratchBytesRef1.length = config.packedBytesLength;
scratchBytesRef1.bytes = leafValues;
final IntFunction<BytesRef> packedValues =
new IntFunction<BytesRef>() {
@Override
public BytesRef apply(int i) {
scratchBytesRef1.offset = config.packedBytesLength * i;
return scratchBytesRef1;
}
};
assert valuesInOrderAndBounds(
config,
leafCount,
0,
ArrayUtil.copyOfSubArray(leafValues, 0, config.packedBytesLength),
ArrayUtil.copyOfSubArray(
leafValues,
(leafCount - 1) * config.packedBytesLength,
leafCount * config.packedBytesLength),
packedValues,
leafDocs,
0);
writeLeafBlockPackedValues(
dataOut, commonPrefixLengths, leafCount, 0, packedValues, leafCardinality);
}
}
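// Worked example (illustrative): numLeaves=5 -> lastFullLevel=2, leavesFullLevel=4, so
// numLeftLeafNodes starts at 2; the single unbalanced leaf is placed on the left, giving the
// left subtree 3 leaves and the right subtree 2.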
private int getNumLeftLeafNodes(int numLeaves) {
assert numLeaves > 1 : "getNumLeftLeafNodes() called with " + numLeaves;
// return the level that can be filled with this number of leaves
int lastFullLevel = 31 - Integer.numberOfLeadingZeros(numLeaves);
// how many leaf nodes are in the full level
int leavesFullLevel = 1 << lastFullLevel;
// half of the leaf nodes from the full level goes to the left
int numLeftLeafNodes = leavesFullLevel / 2;
// leaf nodes that do not fit in the full level
int unbalancedLeafNodes = numLeaves - leavesFullLevel;
// distribute unbalanced leaf nodes
numLeftLeafNodes += Math.min(unbalancedLeafNodes, numLeftLeafNodes);
// we should always place unbalanced leaf nodes on the left
assert numLeftLeafNodes >= numLeaves - numLeftLeafNodes
&& numLeftLeafNodes <= 2L * (numLeaves - numLeftLeafNodes);
return numLeftLeafNodes;
}
// TODO: if we fixed each partition step to just record the file offset at the "split point", we
// could probably handle variable length
// encoding and not have our own ByteSequencesReader/Writer
// useful for debugging:
/*
private void printPathSlice(String desc, PathSlice slice, int dim) throws IOException {
System.out.println(" " + desc + " dim=" + dim + " count=" + slice.count + ":");
try(PointReader r = slice.writer.getReader(slice.start, slice.count)) {
int count = 0;
while (r.next()) {
byte[] v = r.packedValue();
System.out.println(" " + count + ": " + new BytesRef(v, dim*config.bytesPerDim, config.bytesPerDim));
count++;
if (count == slice.count) {
break;
}
}
}
}
*/
private void checkMaxLeafNodeCount(int numLeaves) {
if (config.bytesPerDim * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
throw new IllegalStateException(
"too many nodes; increase config.maxPointsInLeafNode (currently "
+ config.maxPointsInLeafNode
+ ") and reindex");
}
}
/**
* Writes the BKD tree to the provided {@link IndexOutput}s and returns a {@link Runnable} that
* writes the index of the tree if at least one point has been added, or {@code null} otherwise.
*/
public Runnable finish(IndexOutput metaOut, IndexOutput indexOut, IndexOutput dataOut)
throws IOException {
// System.out.println("\nBKDTreeWriter.finish pointCount=" + pointCount + " out=" + out + "
// heapWriter=" + heapPointWriter);
// TODO: specialize the 1D case? it's much faster at indexing time (no partitioning on
// recurse...)
// Catch user silliness:
if (finished == true) {
throw new IllegalStateException("already finished");
}
if (pointCount == 0) {
return null;
}
// mark as finished
finished = true;
pointWriter.close();
BKDRadixSelector.PathSlice points = new BKDRadixSelector.PathSlice(pointWriter, 0, pointCount);
// clean up pointers
tempInput = null;
pointWriter = null;
final int numLeaves =
Math.toIntExact((pointCount + config.maxPointsInLeafNode - 1) / config.maxPointsInLeafNode);
final int numSplits = numLeaves - 1;
checkMaxLeafNodeCount(numLeaves);
// One split value (config.bytesPerDim bytes) and one split dimension (one byte) per inner node:
byte[] splitPackedValues = new byte[Math.toIntExact(numSplits * config.bytesPerDim)];
byte[] splitDimensionValues = new byte[numSplits];
// One file pointer per leaf block:
long[] leafBlockFPs = new long[numLeaves];
// Make sure the math above "worked":
assert pointCount / numLeaves <= config.maxPointsInLeafNode
: "pointCount="
+ pointCount
+ " numLeaves="
+ numLeaves
+ " config.maxPointsInLeafNode="
+ config.maxPointsInLeafNode;
// We re-use the selector so we do not need to create an object every time.
BKDRadixSelector radixSelector =
new BKDRadixSelector(config, maxPointsSortInHeap, tempDir, tempFileNamePrefix);
final long dataStartFP = dataOut.getFilePointer();
boolean success = false;
try {
final int[] parentSplits = new int[config.numIndexDims];
build(
0,
numLeaves,
points,
dataOut,
radixSelector,
minPackedValue.clone(),
maxPackedValue.clone(),
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
new int[config.maxPointsInLeafNode]);
assert Arrays.equals(parentSplits, new int[config.numIndexDims]);
// If no exception, we should have cleaned everything up:
assert tempDir.getCreatedFiles().isEmpty();
// long t2 = System.nanoTime();
// System.out.println("write time: " + ((t2-t1)/1000000.0) + " msec");
success = true;
} finally {
if (success == false) {
IOUtils.deleteFilesIgnoringExceptions(tempDir, tempDir.getCreatedFiles());
}
}
scratchBytesRef1.bytes = splitPackedValues;
scratchBytesRef1.length = config.bytesPerDim;
BKDTreeLeafNodes leafNodes =
new BKDTreeLeafNodes() {
@Override
public long getLeafLP(int index) {
return leafBlockFPs[index];
}
@Override
public BytesRef getSplitValue(int index) {
scratchBytesRef1.offset = index * config.bytesPerDim;
return scratchBytesRef1;
}
@Override
public int getSplitDimension(int index) {
return splitDimensionValues[index] & 0xff;
}
@Override
public int numLeaves() {
return leafBlockFPs.length;
}
};
return () -> {
// Write index:
try {
writeIndex(metaOut, indexOut, config.maxPointsInLeafNode, leafNodes, dataStartFP);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
/**
* Packs the two arrays, representing a semi-balanced binary tree, into a compact byte[]
* structure.
*/
private byte[] packIndex(BKDTreeLeafNodes leafNodes) throws IOException {
// Reused while packing the index:
ByteBuffersDataOutput writeBuffer = ByteBuffersDataOutput.newResettableInstance();
// This is the "file" we append the byte[] to:
List<byte[]> blocks = new ArrayList<>();
byte[] lastSplitValues = new byte[config.bytesPerDim * config.numIndexDims];
// System.out.println("\npack index");
int totalSize =
recursePackIndex(
writeBuffer,
leafNodes,
0L,
blocks,
lastSplitValues,
new boolean[config.numIndexDims],
false,
0,
leafNodes.numLeaves());
// Compact the byte[] blocks into single byte index:
byte[] index = new byte[totalSize];
int upto = 0;
for (byte[] block : blocks) {
System.arraycopy(block, 0, index, upto, block.length);
upto += block.length;
}
assert upto == totalSize;
return index;
}
/** Appends the current contents of writeBuffer as another block on the growing in-memory file */
private int appendBlock(ByteBuffersDataOutput writeBuffer, List<byte[]> blocks) {
byte[] block = writeBuffer.toArrayCopy();
blocks.add(block);
writeBuffer.reset();
return block.length;
}
/**
* lastSplitValues holds the per-dimension split value previously seen; we use it to prefix-code
* the split byte[] on each inner node
*/
private int recursePackIndex(
ByteBuffersDataOutput writeBuffer,
BKDTreeLeafNodes leafNodes,
long minBlockFP,
List<byte[]> blocks,
byte[] lastSplitValues,
boolean[] negativeDeltas,
boolean isLeft,
int leavesOffset,
int numLeaves)
throws IOException {
if (numLeaves == 1) {
if (isLeft) {
assert leafNodes.getLeafLP(leavesOffset) - minBlockFP == 0;
return 0;
} else {
long delta = leafNodes.getLeafLP(leavesOffset) - minBlockFP;
assert leafNodes.numLeaves() == numLeaves || delta > 0
: "expected delta > 0; got numLeaves =" + numLeaves + " and delta=" + delta;
writeBuffer.writeVLong(delta);
return appendBlock(writeBuffer, blocks);
}
} else {
long leftBlockFP;
if (isLeft) {
// The left tree's left most leaf block FP is always the minimal FP:
assert leafNodes.getLeafLP(leavesOffset) == minBlockFP;
leftBlockFP = minBlockFP;
} else {
leftBlockFP = leafNodes.getLeafLP(leavesOffset);
long delta = leftBlockFP - minBlockFP;
assert leafNodes.numLeaves() == numLeaves || delta > 0
: "expected delta > 0; got numLeaves =" + numLeaves + " and delta=" + delta;
writeBuffer.writeVLong(delta);
}
int numLeftLeafNodes = getNumLeftLeafNodes(numLeaves);
final int rightOffset = leavesOffset + numLeftLeafNodes;
final int splitOffset = rightOffset - 1;
int splitDim = leafNodes.getSplitDimension(splitOffset);
BytesRef splitValue = leafNodes.getSplitValue(splitOffset);
int address = splitValue.offset;
// System.out.println("recursePack inner nodeID=" + nodeID + " splitDim=" + splitDim + "
// splitValue=" + new BytesRef(splitPackedValues, address, config.bytesPerDim));
// find common prefix with last split value in this dim:
int prefix =
commonPrefixComparator.compare(
splitValue.bytes, address, lastSplitValues, splitDim * config.bytesPerDim);
// System.out.println("writeNodeData nodeID=" + nodeID + " splitDim=" + splitDim + " numDims="
// + numDims + " config.bytesPerDim=" + config.bytesPerDim + " prefix=" + prefix);
int firstDiffByteDelta;
if (prefix < config.bytesPerDim) {
// System.out.println(" delta byte cur=" +
// Integer.toHexString(splitPackedValues[address+prefix]&0xFF) + " prev=" +
// Integer.toHexString(lastSplitValues[splitDim * config.bytesPerDim + prefix]&0xFF) + "
// negated?=" + negativeDeltas[splitDim]);
firstDiffByteDelta =
(splitValue.bytes[address + prefix] & 0xFF)
- (lastSplitValues[splitDim * config.bytesPerDim + prefix] & 0xFF);
if (negativeDeltas[splitDim]) {
firstDiffByteDelta = -firstDiffByteDelta;
}
// System.out.println(" delta=" + firstDiffByteDelta);
assert firstDiffByteDelta > 0;
} else {
firstDiffByteDelta = 0;
}
// pack the prefix, splitDim and delta first diff byte into a single vInt:
int code =
(firstDiffByteDelta * (1 + config.bytesPerDim) + prefix) * config.numIndexDims + splitDim;
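// Worked example (illustrative): with bytesPerDim=8, numIndexDims=2, firstDiffByteDelta=1,
// prefix=3 and splitDim=0, code = (1 * 9 + 3) * 2 + 0 = 24; the reader inverts this as
// splitDim = code % numIndexDims and prefix = (code / numIndexDims) % (1 + bytesPerDim).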
// System.out.println(" code=" + code);
// System.out.println(" splitValue=" + new BytesRef(splitPackedValues, address,
// config.bytesPerDim));
writeBuffer.writeVInt(code);
// write the split value, prefix coded vs. our parent's split value:
int suffix = config.bytesPerDim - prefix;
byte[] savSplitValue = new byte[suffix];
if (suffix > 1) {
writeBuffer.writeBytes(splitValue.bytes, address + prefix + 1, suffix - 1);
}
byte[] cmp = lastSplitValues.clone();
System.arraycopy(
lastSplitValues, splitDim * config.bytesPerDim + prefix, savSplitValue, 0, suffix);
// copy our split value into lastSplitValues for our children to prefix-code against
System.arraycopy(
splitValue.bytes,
address + prefix,
lastSplitValues,
splitDim * config.bytesPerDim + prefix,
suffix);
int numBytes = appendBlock(writeBuffer, blocks);
// placeholder for left-tree numBytes; we need this so that at search time, if we only need
// to recurse into the right sub-tree, we can quickly seek to its starting point
int idxSav = blocks.size();
blocks.add(null);
boolean savNegativeDelta = negativeDeltas[splitDim];
negativeDeltas[splitDim] = true;
int leftNumBytes =
recursePackIndex(
writeBuffer,
leafNodes,
leftBlockFP,
blocks,
lastSplitValues,
negativeDeltas,
true,
leavesOffset,
numLeftLeafNodes);
if (numLeftLeafNodes != 1) {
writeBuffer.writeVInt(leftNumBytes);
} else {
assert leftNumBytes == 0 : "leftNumBytes=" + leftNumBytes;
}
byte[] bytes2 = writeBuffer.toArrayCopy();
writeBuffer.reset();
// replace our placeholder:
blocks.set(idxSav, bytes2);
negativeDeltas[splitDim] = false;
int rightNumBytes =
recursePackIndex(
writeBuffer,
leafNodes,
leftBlockFP,
blocks,
lastSplitValues,
negativeDeltas,
false,
rightOffset,
numLeaves - numLeftLeafNodes);
negativeDeltas[splitDim] = savNegativeDelta;
// restore lastSplitValues to what caller originally passed us:
System.arraycopy(
savSplitValue, 0, lastSplitValues, splitDim * config.bytesPerDim + prefix, suffix);
assert Arrays.equals(lastSplitValues, cmp);
return numBytes + bytes2.length + leftNumBytes + rightNumBytes;
}
}
private void writeIndex(
IndexOutput metaOut,
IndexOutput indexOut,
int countPerLeaf,
BKDTreeLeafNodes leafNodes,
long dataStartFP)
throws IOException {
byte[] packedIndex = packIndex(leafNodes);
writeIndex(metaOut, indexOut, countPerLeaf, leafNodes.numLeaves(), packedIndex, dataStartFP);
}
private void writeIndex(
IndexOutput metaOut,
IndexOutput indexOut,
int countPerLeaf,
int numLeaves,
byte[] packedIndex,
long dataStartFP)
throws IOException {
CodecUtil.writeHeader(metaOut, CODEC_NAME, VERSION_CURRENT);
metaOut.writeVInt(config.numDims);
metaOut.writeVInt(config.numIndexDims);
metaOut.writeVInt(countPerLeaf);
metaOut.writeVInt(config.bytesPerDim);
assert numLeaves > 0;
metaOut.writeVInt(numLeaves);
metaOut.writeBytes(minPackedValue, 0, config.packedIndexBytesLength);
metaOut.writeBytes(maxPackedValue, 0, config.packedIndexBytesLength);
metaOut.writeVLong(pointCount);
metaOut.writeVInt(docsSeen.cardinality());
metaOut.writeVInt(packedIndex.length);
metaOut.writeLong(dataStartFP);
// If metaOut and indexOut are the same file, we account for the fact that
// writing a long makes the index start 8 bytes later.
metaOut.writeLong(indexOut.getFilePointer() + (metaOut == indexOut ? Long.BYTES : 0));
indexOut.writeBytes(packedIndex, 0, packedIndex.length);
}
private void writeLeafBlockDocs(DataOutput out, int[] docIDs, int start, int count)
throws IOException {
assert count > 0 : "config.maxPointsInLeafNode=" + config.maxPointsInLeafNode;
out.writeVInt(count);
docIdsWriter.writeDocIds(docIDs, start, count, out);
}
private void writeLeafBlockPackedValues(
DataOutput out,
int[] commonPrefixLengths,
int count,
int sortedDim,
IntFunction<BytesRef> packedValues,
int leafCardinality)
throws IOException {
int prefixLenSum = Arrays.stream(commonPrefixLengths).sum();
if (prefixLenSum == config.packedBytesLength) {
// all values in this block are equal
out.writeByte((byte) -1);
} else {
assert commonPrefixLengths[sortedDim] < config.bytesPerDim;
// estimate if storing the values with cardinality is cheaper than storing all values.
int compressedByteOffset = sortedDim * config.bytesPerDim + commonPrefixLengths[sortedDim];
int highCardinalityCost;
int lowCardinalityCost;
if (count == leafCardinality) {
// all values in this block are different
highCardinalityCost = 0;
lowCardinalityCost = 1;
} else {
// compute cost of runLen compression
int numRunLens = 0;
for (int i = 0; i < count; ) {
// do run-length compression on the byte at compressedByteOffset
int runLen = runLen(packedValues, i, Math.min(i + 0xff, count), compressedByteOffset);
assert runLen <= 0xff;
numRunLens++;
i += runLen;
}
// Add cost of runLen compression
highCardinalityCost =
count * (config.packedBytesLength - prefixLenSum - 1) + 2 * numRunLens;
// +1 is the byte needed for storing the cardinality
lowCardinalityCost = leafCardinality * (config.packedBytesLength - prefixLenSum + 1);
}
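// Worked example (illustrative): count=100, leafCardinality=10, packedBytesLength=16,
// prefixLenSum=10, numRunLens=4 -> highCardinalityCost = 100 * (16 - 10 - 1) + 2 * 4 = 508
// and lowCardinalityCost = 10 * (16 - 10 + 1) = 70, so the low-cardinality encoding wins.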
if (lowCardinalityCost <= highCardinalityCost) {
out.writeByte((byte) -2);
writeLowCardinalityLeafBlockPackedValues(out, commonPrefixLengths, count, packedValues);
} else {
out.writeByte((byte) sortedDim);
writeHighCardinalityLeafBlockPackedValues(
out, commonPrefixLengths, count, sortedDim, packedValues, compressedByteOffset);
}
}
}
private void writeLowCardinalityLeafBlockPackedValues(
DataOutput out, int[] commonPrefixLengths, int count, IntFunction<BytesRef> packedValues)
throws IOException {
if (config.numIndexDims != 1) {
writeActualBounds(out, commonPrefixLengths, count, packedValues);
}
BytesRef value = packedValues.apply(0);
System.arraycopy(value.bytes, value.offset, scratch1, 0, config.packedBytesLength);
int cardinality = 1;
for (int i = 1; i < count; i++) {
value = packedValues.apply(i);
for (int dim = 0; dim < config.numDims; dim++) {
final int start = dim * config.bytesPerDim;
if (equalsPredicate.test(value.bytes, value.offset + start, scratch1, start) == false) {
out.writeVInt(cardinality);
for (int j = 0; j < config.numDims; j++) {
out.writeBytes(
scratch1,
j * config.bytesPerDim + commonPrefixLengths[j],
config.bytesPerDim - commonPrefixLengths[j]);
}
System.arraycopy(value.bytes, value.offset, scratch1, 0, config.packedBytesLength);
cardinality = 1;
break;
} else if (dim == config.numDims - 1) {
cardinality++;
}
}
}
out.writeVInt(cardinality);
for (int i = 0; i < config.numDims; i++) {
out.writeBytes(
scratch1,
i * config.bytesPerDim + commonPrefixLengths[i],
config.bytesPerDim - commonPrefixLengths[i]);
}
}
private void writeHighCardinalityLeafBlockPackedValues(
DataOutput out,
int[] commonPrefixLengths,
int count,
int sortedDim,
IntFunction<BytesRef> packedValues,
int compressedByteOffset)
throws IOException {
if (config.numIndexDims != 1) {
writeActualBounds(out, commonPrefixLengths, count, packedValues);
}
commonPrefixLengths[sortedDim]++;
for (int i = 0; i < count; ) {
// do run-length compression on the byte at compressedByteOffset
int runLen = runLen(packedValues, i, Math.min(i + 0xff, count), compressedByteOffset);
assert runLen <= 0xff;
BytesRef first = packedValues.apply(i);
byte prefixByte = first.bytes[first.offset + compressedByteOffset];
out.writeByte(prefixByte);
out.writeByte((byte) runLen);
writeLeafBlockPackedValuesRange(out, commonPrefixLengths, i, i + runLen, packedValues);
i += runLen;
assert i <= count;
}
}
private void writeActualBounds(
DataOutput out, int[] commonPrefixLengths, int count, IntFunction<BytesRef> packedValues)
throws IOException {
for (int dim = 0; dim < config.numIndexDims; ++dim) {
int commonPrefixLength = commonPrefixLengths[dim];
int suffixLength = config.bytesPerDim - commonPrefixLength;
if (suffixLength > 0) {
BytesRef[] minMax =
computeMinMax(
count, packedValues, dim * config.bytesPerDim + commonPrefixLength, suffixLength);
BytesRef min = minMax[0];
BytesRef max = minMax[1];
out.writeBytes(min.bytes, min.offset, min.length);
out.writeBytes(max.bytes, max.offset, max.length);
}
}
}
/**
* Return an array that contains the min and max values for the [offset, offset+length) interval
* of the given {@link BytesRef}s.
*/
private static BytesRef[] computeMinMax(
int count, IntFunction<BytesRef> packedValues, int offset, int length) {
assert length > 0;
BytesRefBuilder min = new BytesRefBuilder();
BytesRefBuilder max = new BytesRefBuilder();
BytesRef first = packedValues.apply(0);
min.copyBytes(first.bytes, first.offset + offset, length);
max.copyBytes(first.bytes, first.offset + offset, length);
for (int i = 1; i < count; ++i) {
BytesRef candidate = packedValues.apply(i);
if (Arrays.compareUnsigned(
min.bytes(),
0,
length,
candidate.bytes,
candidate.offset + offset,
candidate.offset + offset + length)
> 0) {
min.copyBytes(candidate.bytes, candidate.offset + offset, length);
} else if (Arrays.compareUnsigned(
max.bytes(),
0,
length,
candidate.bytes,
candidate.offset + offset,
candidate.offset + offset + length)
< 0) {
max.copyBytes(candidate.bytes, candidate.offset + offset, length);
}
}
return new BytesRef[] {min.get(), max.get()};
}
private void writeLeafBlockPackedValuesRange(
DataOutput out,
int[] commonPrefixLengths,
int start,
int end,
IntFunction<BytesRef> packedValues)
throws IOException {
for (int i = start; i < end; ++i) {
BytesRef ref = packedValues.apply(i);
assert ref.length == config.packedBytesLength;
for (int dim = 0; dim < config.numDims; dim++) {
int prefix = commonPrefixLengths[dim];
out.writeBytes(
ref.bytes, ref.offset + dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
}
}
}
private static int runLen(
IntFunction<BytesRef> packedValues, int start, int end, int byteOffset) {
BytesRef first = packedValues.apply(start);
byte b = first.bytes[first.offset + byteOffset];
for (int i = start + 1; i < end; ++i) {
BytesRef ref = packedValues.apply(i);
byte b2 = ref.bytes[ref.offset + byteOffset];
assert Byte.toUnsignedInt(b2) >= Byte.toUnsignedInt(b);
if (b != b2) {
return i - start;
}
}
return end - start;
}
private void writeCommonPrefixes(DataOutput out, int[] commonPrefixes, byte[] packedValue)
throws IOException {
for (int dim = 0; dim < config.numDims; dim++) {
out.writeVInt(commonPrefixes[dim]);
// System.out.println(commonPrefixes[dim] + " of " + config.bytesPerDim);
out.writeBytes(packedValue, dim * config.bytesPerDim, commonPrefixes[dim]);
}
}
@Override
public void close() throws IOException {
finished = true;
if (tempInput != null) {
// NOTE: this should only happen on exception, e.g. caller calls close w/o calling finish:
try {
tempInput.close();
} finally {
tempDir.deleteFile(tempInput.getName());
tempInput = null;
}
}
}
/**
* Called on exception, to check whether the checksum is also corrupt in this source, and add that
* information (checksum matched or didn't) as a suppressed exception.
*/
private Error verifyChecksum(Throwable priorException, PointWriter writer) throws IOException {
assert priorException != null;
// TODO: we could improve this, to always validate checksum as we recurse, if we shared left and
// right reader after recursing to children, and possibly within recursed children,
// since all together they make a single pass through the file. But this is a sizable re-org,
// and would mean leaving readers (IndexInputs) open for longer:
if (writer instanceof OfflinePointWriter) {
// We are reading from a temp file; go verify the checksum:
String tempFileName = ((OfflinePointWriter) writer).name;
if (tempDir.getCreatedFiles().contains(tempFileName)) {
try (ChecksumIndexInput in = tempDir.openChecksumInput(tempFileName, IOContext.READONCE)) {
CodecUtil.checkFooter(in, priorException);
}
}
}
// We are reading from heap; nothing to add:
throw IOUtils.rethrowAlways(priorException);
}
/**
* Pick the next dimension to split.
*
* @param minPackedValue the min values for all dimensions
* @param maxPackedValue the max values for all dimensions
* @param parentSplits how many times each dim has been split on the parent levels
* @return the dimension to split
*/
protected int split(byte[] minPackedValue, byte[] maxPackedValue, int[] parentSplits) {
// First look for a dimension that has been split less than half as many times as the
// most-split dimension, and return it if its values are not all equal. This helps ensure
// all dimensions get indexed.
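// Example (illustrative): parentSplits = {4, 1, 0} gives maxNumSplits = 4; dims 1 and 2 have
// been split fewer than 4 / 2 = 2 times, so the first of them whose min != max is returned.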
int maxNumSplits = 0;
for (int numSplits : parentSplits) {
maxNumSplits = Math.max(maxNumSplits, numSplits);
}
for (int dim = 0; dim < config.numIndexDims; ++dim) {
final int offset = dim * config.bytesPerDim;
if (parentSplits[dim] < maxNumSplits / 2
&& comparator.compare(minPackedValue, offset, maxPackedValue, offset) != 0) {
return dim;
}
}
// Find which dim has the largest span so we can split on it:
int splitDim = -1;
for (int dim = 0; dim < config.numIndexDims; dim++) {
NumericUtils.subtract(config.bytesPerDim, dim, maxPackedValue, minPackedValue, scratchDiff);
if (splitDim == -1 || comparator.compare(scratchDiff, 0, scratch1, 0) > 0) {
System.arraycopy(scratchDiff, 0, scratch1, 0, config.bytesPerDim);
splitDim = dim;
}
}
// System.out.println("SPLIT: " + splitDim);
return splitDim;
}
/** Pull a partition back into heap once the point count is low enough while recursing. */
private HeapPointWriter switchToHeap(PointWriter source) throws IOException {
int count = Math.toIntExact(source.count());
try (PointReader reader = source.getReader(0, source.count());
HeapPointWriter writer = new HeapPointWriter(config, count)) {
for (int i = 0; i < count; i++) {
boolean hasNext = reader.next();
assert hasNext;
writer.append(reader.pointValue());
}
source.destroy();
return writer;
} catch (Throwable t) {
throw verifyChecksum(t, source);
}
}
/* Recursively reorders the provided reader and writes the bkd-tree on the fly; this method is used
* when we are writing a new segment directly from IndexWriter's indexing buffer (MutablePointTree). */
private void build(
int leavesOffset,
int numLeaves,
MutablePointTree reader,
int from,
int to,
IndexOutput out,
byte[] minPackedValue,
byte[] maxPackedValue,
int[] parentSplits,
byte[] splitPackedValues,
byte[] splitDimensionValues,
long[] leafBlockFPs,
int[] spareDocIds)
throws IOException {
if (numLeaves == 1) {
// leaf node
final int count = to - from;
assert count <= config.maxPointsInLeafNode;
// Compute common prefixes
Arrays.fill(commonPrefixLengths, config.bytesPerDim);
reader.getValue(from, scratchBytesRef1);
for (int i = from + 1; i < to; ++i) {
reader.getValue(i, scratchBytesRef2);
for (int dim = 0; dim < config.numDims; dim++) {
final int offset = dim * config.bytesPerDim;
int dimensionPrefixLength = commonPrefixLengths[dim];
commonPrefixLengths[dim] =
Math.min(
dimensionPrefixLength,
commonPrefixComparator.compare(
scratchBytesRef1.bytes,
scratchBytesRef1.offset + offset,
scratchBytesRef2.bytes,
scratchBytesRef2.offset + offset));
}
}
// Find the dimension that has the least number of unique bytes at commonPrefixLengths[dim]
FixedBitSet[] usedBytes = new FixedBitSet[config.numDims];
for (int dim = 0; dim < config.numDims; ++dim) {
if (commonPrefixLengths[dim] < config.bytesPerDim) {
usedBytes[dim] = new FixedBitSet(256);
}
}
for (int i = from + 1; i < to; ++i) {
for (int dim = 0; dim < config.numDims; dim++) {
if (usedBytes[dim] != null) {
byte b = reader.getByteAt(i, dim * config.bytesPerDim + commonPrefixLengths[dim]);
usedBytes[dim].set(Byte.toUnsignedInt(b));
}
}
}
int sortedDim = 0;
int sortedDimCardinality = Integer.MAX_VALUE;
for (int dim = 0; dim < config.numDims; ++dim) {
if (usedBytes[dim] != null) {
final int cardinality = usedBytes[dim].cardinality();
if (cardinality < sortedDimCardinality) {
sortedDim = dim;
sortedDimCardinality = cardinality;
}
}
}
// sort by sortedDim
MutablePointTreeReaderUtils.sortByDim(
config,
sortedDim,
commonPrefixLengths,
reader,
from,
to,
scratchBytesRef1,
scratchBytesRef2);
BytesRef comparator = scratchBytesRef1;
BytesRef collector = scratchBytesRef2;
reader.getValue(from, comparator);
int leafCardinality = 1;
for (int i = from + 1; i < to; ++i) {
reader.getValue(i, collector);
for (int dim = 0; dim < config.numDims; dim++) {
final int start = dim * config.bytesPerDim;
if (equalsPredicate.test(
collector.bytes,
collector.offset + start,
comparator.bytes,
comparator.offset + start)
== false) {
leafCardinality++;
BytesRef scratch = collector;
collector = comparator;
comparator = scratch;
break;
}
}
}
// Save the block file pointer:
leafBlockFPs[leavesOffset] = out.getFilePointer();
// Write doc IDs
int[] docIDs = spareDocIds;
for (int i = from; i < to; ++i) {
docIDs[i - from] = reader.getDocID(i);
}
// System.out.println("writeLeafBlock pos=" + out.getFilePointer());
writeLeafBlockDocs(out, docIDs, 0, count);
// Write the common prefixes:
reader.getValue(from, scratchBytesRef1);
System.arraycopy(
scratchBytesRef1.bytes, scratchBytesRef1.offset, scratch1, 0, config.packedBytesLength);
writeCommonPrefixes(out, commonPrefixLengths, scratch1);
// Write the full values:
IntFunction<BytesRef> packedValues =
new IntFunction<BytesRef>() {
@Override
public BytesRef apply(int i) {
reader.getValue(from + i, scratchBytesRef1);
return scratchBytesRef1;
}
};
assert valuesInOrderAndBounds(
config, count, sortedDim, minPackedValue, maxPackedValue, packedValues, docIDs, 0);
writeLeafBlockPackedValues(
out, commonPrefixLengths, count, sortedDim, packedValues, leafCardinality);
} else {
// inner node
final int splitDim;
// compute the split dimension and partition around it
if (config.numIndexDims == 1) {
splitDim = 0;
} else {
// For > 2 dims we recompute the exact bounds of the current inner node to help the
// algorithm choose the best split dimension. Because this is an expensive operation, we only
// do it once every SPLITS_BEFORE_EXACT_BOUNDS splits.
if (numLeaves != leafBlockFPs.length
&& config.numIndexDims > 2
&& Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
computePackedValueBounds(
reader, from, to, minPackedValue, maxPackedValue, scratchBytesRef1);
}
splitDim = split(minPackedValue, maxPackedValue, parentSplits);
}
// How many leaves will be in the left tree:
int numLeftLeafNodes = getNumLeftLeafNodes(numLeaves);
// How many points will be in the left tree:
final int mid = from + numLeftLeafNodes * config.maxPointsInLeafNode;
final int commonPrefixLen =
commonPrefixComparator.compare(
minPackedValue,
splitDim * config.bytesPerDim,
maxPackedValue,
splitDim * config.bytesPerDim);
MutablePointTreeReaderUtils.partition(
config,
maxDoc,
splitDim,
commonPrefixLen,
reader,
from,
to,
mid,
scratchBytesRef1,
scratchBytesRef2);
final int rightOffset = leavesOffset + numLeftLeafNodes;
final int splitOffset = rightOffset - 1;
// set the split value
final int address = splitOffset * config.bytesPerDim;
splitDimensionValues[splitOffset] = (byte) splitDim;
reader.getValue(mid, scratchBytesRef1);
System.arraycopy(
scratchBytesRef1.bytes,
scratchBytesRef1.offset + splitDim * config.bytesPerDim,
splitPackedValues,
address,
config.bytesPerDim);
byte[] minSplitPackedValue =
ArrayUtil.copyOfSubArray(minPackedValue, 0, config.packedIndexBytesLength);
byte[] maxSplitPackedValue =
ArrayUtil.copyOfSubArray(maxPackedValue, 0, config.packedIndexBytesLength);
System.arraycopy(
scratchBytesRef1.bytes,
scratchBytesRef1.offset + splitDim * config.bytesPerDim,
minSplitPackedValue,
splitDim * config.bytesPerDim,
config.bytesPerDim);
System.arraycopy(
scratchBytesRef1.bytes,
scratchBytesRef1.offset + splitDim * config.bytesPerDim,
maxSplitPackedValue,
splitDim * config.bytesPerDim,
config.bytesPerDim);
// recurse
parentSplits[splitDim]++;
build(
leavesOffset,
numLeftLeafNodes,
reader,
from,
mid,
out,
minPackedValue,
maxSplitPackedValue,
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
spareDocIds);
build(
rightOffset,
numLeaves - numLeftLeafNodes,
reader,
mid,
to,
out,
minSplitPackedValue,
maxPackedValue,
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
spareDocIds);
parentSplits[splitDim]--;
}
}
private void computePackedValueBounds(
BKDRadixSelector.PathSlice slice, byte[] minPackedValue, byte[] maxPackedValue)
throws IOException {
try (PointReader reader = slice.writer.getReader(slice.start, slice.count)) {
if (reader.next() == false) {
return;
}
BytesRef value = reader.pointValue().packedValue();
System.arraycopy(value.bytes, value.offset, minPackedValue, 0, config.packedIndexBytesLength);
System.arraycopy(value.bytes, value.offset, maxPackedValue, 0, config.packedIndexBytesLength);
while (reader.next()) {
value = reader.pointValue().packedValue();
for (int dim = 0; dim < config.numIndexDims; dim++) {
final int startOffset = dim * config.bytesPerDim;
if (comparator.compare(
value.bytes, value.offset + startOffset, minPackedValue, startOffset)
< 0) {
System.arraycopy(
value.bytes,
value.offset + startOffset,
minPackedValue,
startOffset,
config.bytesPerDim);
} else if (comparator.compare(
value.bytes, value.offset + startOffset, maxPackedValue, startOffset)
> 0) {
System.arraycopy(
value.bytes,
value.offset + startOffset,
maxPackedValue,
startOffset,
config.bytesPerDim);
}
}
}
}
}
/**
* The point writer contains the data that is going to be split using radix selection. This
* method is used when we are merging previously written segments, in the numDims > 1 case.
*/
private void build(
int leavesOffset,
int numLeaves,
BKDRadixSelector.PathSlice points,
IndexOutput out,
BKDRadixSelector radixSelector,
byte[] minPackedValue,
byte[] maxPackedValue,
int[] parentSplits,
byte[] splitPackedValues,
byte[] splitDimensionValues,
long[] leafBlockFPs,
int[] spareDocIds)
throws IOException {
if (numLeaves == 1) {
// Leaf node: write block
// We can write the block in any order, so by default we write it sorted by the dimension
// that has the least number of unique bytes at commonPrefixLengths[dim], which makes
// compression more efficient
HeapPointWriter heapSource;
if (points.writer instanceof HeapPointWriter == false) {
// Adversarial cases can cause this, e.g. merging big segments with most of the points
// deleted
heapSource = switchToHeap(points.writer);
} else {
heapSource = (HeapPointWriter) points.writer;
}
int from = Math.toIntExact(points.start);
int to = Math.toIntExact(points.start + points.count);
// we store the common prefix in scratch1
computeCommonPrefixLength(heapSource, scratch1, from, to);
int sortedDim = 0;
int sortedDimCardinality = Integer.MAX_VALUE;
FixedBitSet[] usedBytes = new FixedBitSet[config.numDims];
for (int dim = 0; dim < config.numDims; ++dim) {
if (commonPrefixLengths[dim] < config.bytesPerDim) {
usedBytes[dim] = new FixedBitSet(256);
}
}
// Find the dimension to compress
for (int dim = 0; dim < config.numDims; dim++) {
int prefix = commonPrefixLengths[dim];
if (prefix < config.bytesPerDim) {
int offset = dim * config.bytesPerDim;
for (int i = from; i < to; ++i) {
PointValue value = heapSource.getPackedValueSlice(i);
BytesRef packedValue = value.packedValue();
int bucket = packedValue.bytes[packedValue.offset + offset + prefix] & 0xff;
usedBytes[dim].set(bucket);
}
int cardinality = usedBytes[dim].cardinality();
if (cardinality < sortedDimCardinality) {
sortedDim = dim;
sortedDimCardinality = cardinality;
}
}
}
// sort the chosen dimension
radixSelector.heapRadixSort(heapSource, from, to, sortedDim, commonPrefixLengths[sortedDim]);
// compute cardinality
int leafCardinality = heapSource.computeCardinality(from, to, commonPrefixLengths);
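// leafCardinality is the number of distinct values in the leaf; writeLeafBlockPackedValues
// uses it to decide whether run-length encoding duplicate values beats prefix-coding every
// value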
// Save the block file pointer:
leafBlockFPs[leavesOffset] = out.getFilePointer();
// System.out.println(" write leaf block @ fp=" + out.getFilePointer());
// Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
// loading the values:
int count = to - from;
assert count > 0 : "numLeaves=" + numLeaves + " leavesOffset=" + leavesOffset;
assert count <= spareDocIds.length : "count=" + count + " > length=" + spareDocIds.length;
// Write doc IDs
int[] docIDs = spareDocIds;
for (int i = 0; i < count; i++) {
docIDs[i] = heapSource.getPackedValueSlice(from + i).docID();
}
writeLeafBlockDocs(out, docIDs, 0, count);
// TODO: minor opto: we don't really have to write the actual common prefixes, because
// BKDReader on recursing can regenerate it for us
// from the index, much like how terms dict does so from the FST:
// Write the common prefixes:
writeCommonPrefixes(out, commonPrefixLengths, scratch1);
// Write the full values:
IntFunction<BytesRef> packedValues =
new IntFunction<BytesRef>() {
@Override
public BytesRef apply(int i) {
PointValue value = heapSource.getPackedValueSlice(from + i);
return value.packedValue();
}
};
assert valuesInOrderAndBounds(
config, count, sortedDim, minPackedValue, maxPackedValue, packedValues, docIDs, 0);
writeLeafBlockPackedValues(
out, commonPrefixLengths, count, sortedDim, packedValues, leafCardinality);
} else {
// Inner node: partition/recurse
final int splitDim;
if (config.numIndexDims == 1) {
splitDim = 0;
} else {
// For dimensions > 2 we recompute the bounds for the current inner node to help the
// algorithm choose the best split dimension. Because this is an expensive operation, the
// frequency at which we recompute the bounds is given by SPLITS_BEFORE_EXACT_BOUNDS.
if (numLeaves != leafBlockFPs.length
&& config.numIndexDims > 2
&& Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
computePackedValueBounds(points, minPackedValue, maxPackedValue);
}
splitDim = split(minPackedValue, maxPackedValue, parentSplits);
}
assert numLeaves <= leafBlockFPs.length
: "numLeaves=" + numLeaves + " leafBlockFPs.length=" + leafBlockFPs.length;
// How many leaves will be in the left tree:
final int numLeftLeafNodes = getNumLeftLeafNodes(numLeaves);
// How many points will be in the left tree; this is exact because every leaf of the left
// subtree is full, holding config.maxPointsInLeafNode points:
final long leftCount = numLeftLeafNodes * (long) config.maxPointsInLeafNode;
BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2];
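// How many leading bytes the min and max values share in the split dimension; the radix
// selector only needs to compare bytes after this common prefix: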
final int commonPrefixLen =
commonPrefixComparator.compare(
minPackedValue,
splitDim * config.bytesPerDim,
maxPackedValue,
splitDim * config.bytesPerDim);
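// Partition the points around the split: slices[0] gets the leftCount smallest values in
// the split dimension, slices[1] the rest, and the split value at the boundary is returned: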
byte[] splitValue =
radixSelector.select(
points,
slices,
points.start,
points.start + points.count,
points.start + leftCount,
splitDim,
commonPrefixLen);
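// Record this inner node's split dimension and split value in the flat index arrays: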
final int rightOffset = leavesOffset + numLeftLeafNodes;
final int splitValueOffset = rightOffset - 1;
splitDimensionValues[splitValueOffset] = (byte) splitDim;
int address = splitValueOffset * config.bytesPerDim;
System.arraycopy(splitValue, 0, splitPackedValues, address, config.bytesPerDim);
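// Compute the children's cell bounds: the split value becomes the left child's max and the
// right child's min along the split dimension: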
byte[] minSplitPackedValue = new byte[config.packedIndexBytesLength];
System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, config.packedIndexBytesLength);
byte[] maxSplitPackedValue = new byte[config.packedIndexBytesLength];
System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, config.packedIndexBytesLength);
System.arraycopy(
splitValue, 0, minSplitPackedValue, splitDim * config.bytesPerDim, config.bytesPerDim);
System.arraycopy(
splitValue, 0, maxSplitPackedValue, splitDim * config.bytesPerDim, config.bytesPerDim);
parentSplits[splitDim]++;
// Recurse on left tree:
build(
leavesOffset,
numLeftLeafNodes,
slices[0],
out,
radixSelector,
minPackedValue,
maxSplitPackedValue,
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
spareDocIds);
// Recurse on right tree:
build(
rightOffset,
numLeaves - numLeftLeafNodes,
slices[1],
out,
radixSelector,
minSplitPackedValue,
maxPackedValue,
parentSplits,
splitPackedValues,
splitDimensionValues,
leafBlockFPs,
spareDocIds);
parentSplits[splitDim]--;
}
}
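/**
* Computes, for each dimension, how many leading bytes all points in {@code [from, to)} share,
* storing the lengths in {@code commonPrefixLengths} and the prefix bytes (taken from the
* first point) in {@code commonPrefix}.
*/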
private void computeCommonPrefixLength(
HeapPointWriter heapPointWriter, byte[] commonPrefix, int from, int to) {
Arrays.fill(commonPrefixLengths, config.bytesPerDim);
PointValue value = heapPointWriter.getPackedValueSlice(from);
BytesRef packedValue = value.packedValue();
for (int dim = 0; dim < config.numDims; dim++) {
System.arraycopy(
packedValue.bytes,
packedValue.offset + dim * config.bytesPerDim,
commonPrefix,
dim * config.bytesPerDim,
config.bytesPerDim);
}
for (int i = from + 1; i < to; i++) {
value = heapPointWriter.getPackedValueSlice(i);
packedValue = value.packedValue();
for (int dim = 0; dim < config.numDims; dim++) {
if (commonPrefixLengths[dim] != 0) {
commonPrefixLengths[dim] =
Math.min(
commonPrefixLengths[dim],
commonPrefixComparator.compare(
commonPrefix,
dim * config.bytesPerDim,
packedValue.bytes,
packedValue.offset + dim * config.bytesPerDim));
}
}
}
}
// only called from assert
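// Verifies that the leaf's values are sorted by sortedDim (ties broken by the data
// dimensions and then by docID) and that every value falls inside the cell's bounds.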
private static boolean valuesInOrderAndBounds(
BKDConfig config,
int count,
int sortedDim,
byte[] minPackedValue,
byte[] maxPackedValue,
IntFunction<BytesRef> values,
int[] docs,
int docsOffset) {
byte[] lastPackedValue = new byte[config.packedBytesLength];
int lastDoc = -1;
for (int i = 0; i < count; i++) {
BytesRef packedValue = values.apply(i);
assert packedValue.length == config.packedBytesLength;
assert valueInOrder(
config,
i,
sortedDim,
lastPackedValue,
packedValue.bytes,
packedValue.offset,
docs[docsOffset + i],
lastDoc);
lastDoc = docs[docsOffset + i];
// Make sure this value does in fact fall within this leaf cell:
assert valueInBounds(config, packedValue, minPackedValue, maxPackedValue);
}
return true;
}
// only called from assert
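// Verifies a single value against its predecessor: it must not sort below it in sortedDim;
// on ties, the data dimensions and finally the docID must be non-decreasing.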
private static boolean valueInOrder(
BKDConfig config,
long ord,
int sortedDim,
byte[] lastPackedValue,
byte[] packedValue,
int packedValueOffset,
int doc,
int lastDoc) {
int dimOffset = sortedDim * config.bytesPerDim;
if (ord > 0) {
int cmp =
Arrays.compareUnsigned(
lastPackedValue,
dimOffset,
dimOffset + config.bytesPerDim,
packedValue,
packedValueOffset + dimOffset,
packedValueOffset + dimOffset + config.bytesPerDim);
if (cmp > 0) {
throw new AssertionError(
"values out of order: last value="
+ new BytesRef(lastPackedValue)
+ " current value="
+ new BytesRef(packedValue, packedValueOffset, config.packedBytesLength)
+ " ord="
+ ord);
}
if (cmp == 0 && config.numDims > config.numIndexDims) {
cmp =
Arrays.compareUnsigned(
lastPackedValue,
config.packedIndexBytesLength,
config.packedBytesLength,
packedValue,
packedValueOffset + config.packedIndexBytesLength,
packedValueOffset + config.packedBytesLength);
if (cmp > 0) {
throw new AssertionError(
"data values out of order: last value="
+ new BytesRef(lastPackedValue)
+ " current value="
+ new BytesRef(packedValue, packedValueOffset, config.packedBytesLength)
+ " ord="
+ ord);
}
}
if (cmp == 0 && doc < lastDoc) {
throw new AssertionError(
"docs out of order: last doc=" + lastDoc + " current doc=" + doc + " ord=" + ord);
}
}
System.arraycopy(packedValue, packedValueOffset, lastPackedValue, 0, config.packedBytesLength);
return true;
}
// only called from assert
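// Verifies that the packed value lies within [minPackedValue, maxPackedValue] in every
// index dimension.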
private static boolean valueInBounds(
BKDConfig config, BytesRef packedValue, byte[] minPackedValue, byte[] maxPackedValue) {
for (int dim = 0; dim < config.numIndexDims; dim++) {
int offset = config.bytesPerDim * dim;
if (Arrays.compareUnsigned(
packedValue.bytes,
packedValue.offset + offset,
packedValue.offset + offset + config.bytesPerDim,
minPackedValue,
offset,
offset + config.bytesPerDim)
< 0) {
return false;
}
if (Arrays.compareUnsigned(
packedValue.bytes,
packedValue.offset + offset,
packedValue.offset + offset + config.bytesPerDim,
maxPackedValue,
offset,
offset + config.bytesPerDim)
> 0) {
return false;
}
}
return true;
}
}