blob: 8a6c17e5a91eb2532fc230fd39002bdb22ddce57 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.util.Arrays;
/**
 * Radix selector.
 *
 * <p>This implementation works similarly to a MSB radix sort except that it only recurses into
 * the sub partition that contains the desired value.
 *
 * @lucene.internal
 */
public abstract class RadixSelector extends Selector {

  // after that many levels of recursion we fall back to introselect anyway
  // this is used as a protection against the fact that radix sort performs
  // worse when there are long common prefixes (probably because of cache
  // locality)
  private static final int LEVEL_THRESHOLD = 8;
  // size of histograms: 256 + 1 to indicate that the string is finished
  private static final int HISTOGRAM_SIZE = 257;
  // buckets below this size will be sorted with introselect
  private static final int LENGTH_THRESHOLD = 100;
  // upper bound on the number of common prefix bytes that a single pass can detect
  private static final int MAX_COMMON_PREFIX_LENGTH = 24;

  // a single histogram shared by all passes: it is cleared at the beginning of every pass and
  // never read again once we recurse into a bucket, so one instance is enough
  private final int[] histogram = new int[HISTOGRAM_SIZE];
  // scratch buffer holding the candidate common prefix (bytes of the first entry of the range)
  private final int[] commonPrefix;
  private final int maxLength;

  /**
   * Sole constructor.
   *
   * @param maxLength the maximum length of keys, pass {@link Integer#MAX_VALUE} if unknown. Must
   *     not be negative.
   */
  protected RadixSelector(int maxLength) {
    this.maxLength = maxLength;
    this.commonPrefix = new int[Math.min(MAX_COMMON_PREFIX_LENGTH, maxLength)];
  }

  /**
   * Return the k-th byte of the entry at index {@code i}, or {@code -1} if its length is less
   * than or equal to {@code k}. This may only be called with a value of {@code k} between
   * {@code 0} included and {@code maxLength} excluded.
   */
  protected abstract int byteAt(int i, int k);

  /**
   * Get a fall-back selector which may assume that the first {@code d} bytes of all compared
   * strings are equal. This fallback selector is used when the range becomes narrow or when the
   * maximum level of recursion has been exceeded.
   */
  protected Selector getFallbackSelector(int d) {
    return new IntroSelector() {

      // bytes of the current pivot, starting at offset d
      private final BytesRefBuilder pivot = new BytesRefBuilder();

      @Override
      protected void swap(int i, int j) {
        RadixSelector.this.swap(i, j);
      }

      @Override
      protected int compare(int i, int j) {
        for (int o = d; o < maxLength; ++o) {
          final int b1 = byteAt(i, o);
          final int b2 = byteAt(j, o);
          if (b1 != b2) {
            // both values are in [-1, 255] so the subtraction cannot overflow
            return b1 - b2;
          } else if (b1 == -1) {
            // both entries end here
            break;
          }
        }
        return 0;
      }

      @Override
      protected void setPivot(int i) {
        pivot.setLength(0);
        for (int o = d; o < maxLength; ++o) {
          final int b = byteAt(i, o);
          if (b == -1) {
            break;
          }
          pivot.append((byte) b);
        }
      }

      @Override
      protected int comparePivot(int j) {
        for (int o = 0; o < pivot.length(); ++o) {
          final int b1 = pivot.byteAt(o) & 0xff;
          final int b2 = byteAt(j, d + o);
          if (b1 != b2) {
            return b1 - b2;
          }
        }
        if (d + pivot.length() == maxLength) {
          // no byte left to look at: byteAt must not be called at offset maxLength
          return 0;
        }
        // the pivot ends here (its next byte is -1): this yields 0 if entry j ends here too
        // and a negative number otherwise
        return -1 - byteAt(j, d + pivot.length());
      }
    };
  }

  @Override
  public void select(int from, int to, int k) {
    checkArgs(from, to, k);
    if (maxLength == 0) {
      // there are no bytes to compare so all entries are equal; the radix pass must not run in
      // that case since it needs at least one slot in the common prefix buffer
      return;
    }
    select(from, to, k, 0, 0);
  }

  /**
   * Dispatch to the fallback selector when the range is narrow or the recursion is deep,
   * otherwise perform a radix pass.
   *
   * @param d the character number to compare
   * @param l the level of recursion
   */
  private void select(int from, int to, int k, int d, int l) {
    if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
      getFallbackSelector(d).select(from, to, k);
    } else {
      radixSelect(from, to, k, d, l);
    }
  }

  /**
   * Skip the prefix shared by all entries of the range, then partition the range on the first
   * byte that differs and continue into the bucket that contains {@code k}.
   *
   * @param d the character number to compare
   * @param l the level of recursion
   */
  private void radixSelect(int from, int to, int k, int d, int l) {
    final int[] histogram = this.histogram;

    // Skipping a common prefix is done iteratively rather than recursively: it does not count
    // as a recursion level and a pass advances by a bounded number of bytes, so recursing here
    // could exhaust the stack on keys that share a very long prefix.
    for (;;) {
      Arrays.fill(histogram, 0);
      final int commonPrefixLength =
          computeCommonPrefixLengthAndBuildHistogram(from, to, d, histogram);
      if (commonPrefixLength == 0) {
        assert assertHistogram(commonPrefixLength, histogram);
        break;
      }
      // if there are no more chars to compare or if all entries fell into the
      // first bucket (which means strings are shorter than d) then we are done
      // otherwise skip the common prefix and try again
      if (d + commonPrefixLength >= maxLength || histogram[0] >= to - from) {
        return;
      }
      d += commonPrefixLength;
    }

    int bucketFrom = from;
    for (int bucket = 0; bucket < HISTOGRAM_SIZE; ++bucket) {
      final int bucketTo = bucketFrom + histogram[bucket];

      if (bucketTo > k) {
        // this is the bucket that contains the k-th entry
        partition(from, to, bucket, bucketFrom, bucketTo, d);

        if (bucket != 0 && d + 1 < maxLength) {
          // all elements in bucket 0 are equal so we only need to recurse if bucket != 0
          select(bucketFrom, bucketTo, k, d + 1, l + 1);
        }
        return;
      }
      bucketFrom = bucketTo;
    }
    throw new AssertionError("Unreachable code");
  }

  // only used from assert
  private boolean assertHistogram(int commonPrefixLength, int[] histogram) {
    int numberOfUniqueBytes = 0;
    for (int freq : histogram) {
      if (freq > 0) {
        numberOfUniqueBytes++;
      }
    }
    if (numberOfUniqueBytes == 1) {
      assert commonPrefixLength >= 1;
    } else {
      assert commonPrefixLength == 0;
    }
    return true;
  }

  /**
   * Return a number for the k-th character between 0 and {@link #HISTOGRAM_SIZE}. Bucket 0 is
   * used for entries that have no k-th byte.
   */
  private int getBucket(int i, int k) {
    return byteAt(i, k) + 1;
  }

  /**
   * Build a histogram of the number of values per {@link #getBucket(int, int) bucket} and return
   * a common prefix length for all visited values.
   *
   * <p>When the returned length is greater than 0, all entries share their k-th byte and the
   * histogram has a single non-empty bucket. When it is 0, the histogram is complete for the
   * whole range.
   *
   * @see #buildHistogram
   */
  private int computeCommonPrefixLengthAndBuildHistogram(
      int from, int to, int k, int[] histogram) {
    final int[] commonPrefix = this.commonPrefix;
    int commonPrefixLength = Math.min(commonPrefix.length, maxLength - k);
    // use the first entry as the candidate prefix, stopping right after its end marker if any
    for (int j = 0; j < commonPrefixLength; ++j) {
      final int b = byteAt(from, k + j);
      commonPrefix[j] = b;
      if (b == -1) {
        commonPrefixLength = j + 1;
        break;
      }
    }

    // shrink the candidate prefix as other entries are found to diverge from it
    int i;
    outer:
    for (i = from + 1; i < to; ++i) {
      for (int j = 0; j < commonPrefixLength; ++j) {
        final int b = byteAt(i, k + j);
        if (b != commonPrefix[j]) {
          commonPrefixLength = j;
          if (commonPrefixLength == 0) { // we have no common prefix
            // all entries visited so far share the same k-th byte, except the current one
            histogram[commonPrefix[0] + 1] = i - from;
            histogram[b + 1] = 1;
            break outer;
          }
          break;
        }
      }
    }

    if (i < to) {
      // the loop got broken because there is no common prefix
      assert commonPrefixLength == 0;
      buildHistogram(i + 1, to, k, histogram);
    } else {
      assert commonPrefixLength > 0;
      histogram[commonPrefix[0] + 1] = to - from;
    }

    return commonPrefixLength;
  }

  /**
   * Build an histogram of the k-th characters of values occurring between offsets {@code from}
   * and {@code to}, using {@link #getBucket}. Counts are added to the existing content of
   * {@code histogram}.
   */
  private void buildHistogram(int from, int to, int k, int[] histogram) {
    for (int i = from; i < to; ++i) {
      histogram[getBucket(i, k)]++;
    }
  }

  /**
   * Reorder elements so that all of them that fall into {@code bucket} are between offsets
   * {@code bucketFrom} and {@code bucketTo}. Entries from lower buckets are moved before
   * {@code bucketFrom} and entries from higher buckets at or after {@code bucketTo}.
   */
  private void partition(int from, int to, int bucket, int bucketFrom, int bucketTo, int d) {
    int left = from;
    int right = to - 1;

    // next free position within [bucketFrom, bucketTo)
    int slot = bucketFrom;

    for (;;) {
      int leftBucket = getBucket(left, d);
      int rightBucket = getBucket(right, d);

      // advance over entries that already are on the correct side, and move entries of the
      // target bucket to the next slot; after a swap the same position is examined again
      while (leftBucket <= bucket && left < bucketFrom) {
        if (leftBucket == bucket) {
          swap(left, slot++);
        } else {
          ++left;
        }
        leftBucket = getBucket(left, d);
      }

      while (rightBucket >= bucket && right >= bucketTo) {
        if (rightBucket == bucket) {
          swap(right, slot++);
        } else {
          --right;
        }
        rightBucket = getBucket(right, d);
      }

      if (left < bucketFrom && right >= bucketTo) {
        // the left entry belongs to a higher bucket and the right entry to a lower one
        swap(left++, right--);
      } else {
        assert left == bucketFrom;
        assert right == bucketTo - 1;
        break;
      }
    }
  }
}