/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.minhash;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
/**
* Generate min hash tokens from an incoming stream of tokens. The incoming tokens would typically
* be 5 word shingles.
*
* <p>The number of hashes used and the number of minimum values kept for each hash can be set. You
* could have 1 hash and keep the 100 lowest values, or 100 hashes and keep the lowest value for
* each. Hashes can also be bucketed over ranges of the 128-bit hash space.
*
* <p>A 128-bit hash is used internally. 5 word shingles drawn from a vocabulary of 10^5 words give
* up to 10^25 combinations, so a 64-bit hash (about 1.8 * 10^19 values) would suffer collisions.
*
* <p>When using multiple hashes, 32 bits encode the hash position, leaving 96 bits (about 8 *
* 10^28 values) for each hash. A single hash uses the full 128 bits.
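*
* <p>A minimal usage sketch; {@code tokens} is an assumed upstream {@link TokenStream} and the
* shingle wiring is illustrative, not part of this class:
*
* <pre class="prettyprint">
* ShingleFilter shingles = new ShingleFilter(tokens, 5, 5);
* shingles.setOutputUnigrams(false);
* TokenStream minHashed =
*     new MinHashFilter(shingles, MinHashFilter.DEFAULT_HASH_COUNT,
*         MinHashFilter.DEFAULT_BUCKET_COUNT, MinHashFilter.DEFAULT_HASH_SET_SIZE, true);
* </pre>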
*/
public class MinHashFilter extends TokenFilter {
private static final int HASH_CACHE_SIZE = 512;
private static final LongPair[] cachedIntHashes = new LongPair[HASH_CACHE_SIZE];
public static final int DEFAULT_HASH_COUNT = 1;
public static final int DEFAULT_HASH_SET_SIZE = 1;
public static final int DEFAULT_BUCKET_COUNT = 512;
static final String MIN_HASH_TYPE = "MIN_HASH";
private final List<List<FixedSizeTreeSet<LongPair>>> minHashSets;
private int hashSetSize = DEFAULT_HASH_SET_SIZE;
private int bucketCount = DEFAULT_BUCKET_COUNT;
private int hashCount = DEFAULT_HASH_COUNT;
private boolean requiresInitialisation = true;
private State endState;
private int hashPosition = -1;
private int bucketPosition = -1;
private long bucketSize;
private final boolean withRotation;
private int endOffset;
private boolean exhausted = false;
private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class);
private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class);
private final TypeAttribute typeAttribute = addAttribute(TypeAttribute.class);
private final PositionIncrementAttribute posIncAttribute =
addAttribute(PositionIncrementAttribute.class);
private final PositionLengthAttribute posLenAttribute =
addAttribute(PositionLengthAttribute.class);
static {
for (int i = 0; i < HASH_CACHE_SIZE; i++) {
cachedIntHashes[i] = new LongPair();
murmurhash3_x64_128(getBytes(i), 0, 4, 0, cachedIntHashes[i]);
}
}
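/** Big-endian byte encoding of an int, used as murmur input when hashing a hash position. */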
static byte[] getBytes(int i) {
byte[] answer = new byte[4];
answer[3] = (byte) (i);
answer[2] = (byte) (i >> 8);
answer[1] = (byte) (i >> 16);
answer[0] = (byte) (i >> 24);
return answer;
}
/**
* Create a MinHash filter.
*
* @param input the token stream
* @param hashCount the number of hashes to use
* @param bucketCount the number of buckets each hash range is split into
* @param hashSetSize the number of minimum values to keep per bucket
* @param withRotation whether to fill empty buckets by rotating in values from non-empty ones
*/
public MinHashFilter(
TokenStream input, int hashCount, int bucketCount, int hashSetSize, boolean withRotation) {
super(input);
if (hashCount <= 0) {
throw new IllegalArgumentException("hashCount must be greater than zero");
}
if (bucketCount <= 0) {
throw new IllegalArgumentException("bucketCount must be greater than zero");
}
if (hashSetSize <= 0) {
throw new IllegalArgumentException("hashSetSize must be greater than zero");
}
this.hashCount = hashCount;
this.bucketCount = bucketCount;
this.hashSetSize = hashSetSize;
this.withRotation = withRotation;
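// Bucket width over the 32-bit selector space, rounded up so bucketCount buckets cover the
// whole range.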
this.bucketSize = (1L << 32) / bucketCount;
if ((1L << 32) % bucketCount != 0) {
bucketSize++;
}
minHashSets = new ArrayList<>(this.hashCount);
for (int i = 0; i < this.hashCount; i++) {
ArrayList<FixedSizeTreeSet<LongPair>> buckets = new ArrayList<>(this.bucketCount);
minHashSets.add(buckets);
for (int j = 0; j < this.bucketCount; j++) {
FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(this.hashSetSize);
buckets.add(minSet);
}
}
doReset();
}
@Override
public final boolean incrementToken() throws IOException {
// Pull the underlying stream of tokens
// Hash each token found
// Generate the required number of variants of this hash
// Keep the minimum hash value found so far of each variant
int positionIncrement = 0;
if (requiresInitialisation) {
requiresInitialisation = false;
boolean found = false;
// First time through so we pull and hash everything
while (input.incrementToken()) {
found = true;
String current = new String(termAttribute.buffer(), 0, termAttribute.length());
// Hash the token once; the per-position variants are derived from this hash below.
byte[] bytes = current.getBytes(StandardCharsets.UTF_16LE);
LongPair hash = new LongPair();
murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash);
for (int i = 0; i < hashCount; i++) {
LongPair rehashed = combineOrdered(hash, getIntHash(i));
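// The top 32 bits of the re-hashed value select a bucket; each bucket keeps the
// hashSetSize smallest values seen for this hash position.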
minHashSets.get(i).get((int) ((rehashed.val2 >>> 32) / bucketSize)).add(rehashed);
}
endOffset = offsetAttribute.endOffset();
}
exhausted = true;
input.end();
// We need the end state so an underlying shingle filter can have its state restored
// correctly.
endState = captureState();
if (!found) {
return false;
}
positionIncrement = 1;
// Fix up any empty buckets: when rotating with a single min value per bucket, an
// empty bucket borrows the value from the next non-empty bucket (wrapping around),
// so every bucket emits a hash.
if (withRotation && (hashSetSize == 1)) {
for (int hashLoop = 0; hashLoop < hashCount; hashLoop++) {
for (int bucketLoop = 0; bucketLoop < bucketCount; bucketLoop++) {
if (minHashSets.get(hashLoop).get(bucketLoop).size() == 0) {
for (int bucketOffset = 1; bucketOffset < bucketCount; bucketOffset++) {
if (minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount).size()
> 0) {
LongPair replacementHash =
minHashSets
.get(hashLoop)
.get((bucketLoop + bucketOffset) % bucketCount)
.first();
minHashSets.get(hashLoop).get(bucketLoop).add(replacementHash);
break;
}
}
}
}
}
}
}
clearAttributes();
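// Walk hash positions and buckets, resuming where the previous call left off;
// -1 marks a position that has not been entered yet.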
while (hashPosition < hashCount) {
if (hashPosition == -1) {
hashPosition++;
} else {
while (bucketPosition < bucketCount) {
if (bucketPosition == -1) {
bucketPosition++;
} else {
LongPair hash = minHashSets.get(hashPosition).get(bucketPosition).pollFirst();
if (hash != null) {
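// Emit the hash as a term: with multiple hashes, two chars (32 bits) of hash
// position followed by six chars (96 bits) of hash value; with a single hash,
// the full 128 bits as eight chars.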
termAttribute.setEmpty();
if (hashCount > 1) {
termAttribute.append(int0(hashPosition));
termAttribute.append(int1(hashPosition));
}
long high = hash.val2;
termAttribute.append(long0(high));
termAttribute.append(long1(high));
termAttribute.append(long2(high));
termAttribute.append(long3(high));
long low = hash.val1;
termAttribute.append(long0(low));
termAttribute.append(long1(low));
if (hashCount == 1) {
termAttribute.append(long2(low));
termAttribute.append(long3(low));
}
posIncAttribute.setPositionIncrement(positionIncrement);
offsetAttribute.setOffset(0, endOffset);
typeAttribute.setType(MIN_HASH_TYPE);
posLenAttribute.setPositionLength(1);
return true;
} else {
bucketPosition++;
}
}
}
bucketPosition = -1;
hashPosition++;
}
}
return false;
}
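/** Hash of the hash position i, served from a precomputed cache for i &lt; HASH_CACHE_SIZE. */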
private static LongPair getIntHash(int i) {
if (i < HASH_CACHE_SIZE) {
return cachedIntHashes[i];
} else {
LongPair answer = new LongPair();
murmurhash3_x64_128(getBytes(i), 0, 4, 0, answer);
return answer;
}
}
@Override
public void end() throws IOException {
if (!exhausted) {
input.end();
}
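// Restore the state captured when the input was exhausted so end offsets are reported correctly.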
restoreState(endState);
}
@Override
public void reset() throws IOException {
super.reset();
doReset();
}
private void doReset() {
for (int i = 0; i < hashCount; i++) {
for (int j = 0; j < bucketCount; j++) {
minHashSets.get(i).get(j).clear();
}
}
endState = null;
hashPosition = -1;
bucketPosition = -1;
requiresInitialisation = true;
exhausted = false;
}
private static char long0(long x) {
return (char) (x >> 48);
}
private static char long1(long x) {
return (char) (x >> 32);
}
private static char long2(long x) {
return (char) (x >> 16);
}
private static char long3(long x) {
return (char) (x);
}
private static char int0(int x) {
return (char) (x >> 16);
}
private static char int1(int x) {
return (char) (x);
}
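/** Unsigned 64-bit "less than": the signed comparison is flipped when exactly one operand is negative. */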
static boolean isLessThanUnsigned(long n1, long n2) {
return (n1 < n2) ^ ((n1 < 0) != (n2 < 0));
}
static class FixedSizeTreeSet<E extends Comparable<E>> extends TreeSet<E> {
private static final long serialVersionUID = -8237117170340299630L;
private final int capacity;
FixedSizeTreeSet() {
this(20);
}
FixedSizeTreeSet(int capacity) {
super();
this.capacity = capacity;
}
@Override
public boolean add(final E toAdd) {
if (capacity <= size()) {
final E lastElm = last();
if (toAdd.compareTo(lastElm) >= 0) {
return false;
}
// A duplicate of an existing element would not be inserted by TreeSet.add(),
// so do not evict the current maximum for it.
if (contains(toAdd)) {
return false;
}
pollLast();
}
return super.add(toAdd);
}
}
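/** Combines 128-bit hashes in an order-dependent way: each 64-bit half is folded in as result = result * 37 + next. */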
private static LongPair combineOrdered(LongPair... hashCodes) {
LongPair result = new LongPair();
for (LongPair hashCode : hashCodes) {
result.val1 = result.val1 * 37 + hashCode.val1;
result.val2 = result.val2 * 37 + hashCode.val2;
}
return result;
}
/** 128 bits of state */
static final class LongPair implements Comparable<LongPair> {
public long val1;
public long val2;
@Override
public int compareTo(LongPair other) {
if (isLessThanUnsigned(val2, other.val2)) {
return -1;
} else if (val2 == other.val2) {
if (isLessThanUnsigned(val1, other.val1)) {
return -1;
} else if (val1 == other.val1) {
return 0;
} else {
return 1;
}
} else {
return 1;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LongPair longPair = (LongPair) o;
return val1 == longPair.val1 && val2 == longPair.val2;
}
@Override
public int hashCode() {
int result = (int) (val1 ^ (val1 >>> 32));
result = 31 * result + (int) (val2 ^ (val2 >>> 32));
return result;
}
}
/** Gets a long from a byte buffer in little endian byte order. */
private static long getLongLittleEndian(byte[] buf, int offset) {
return ((long) buf[offset + 7] << 56) // no mask needed
| ((buf[offset + 6] & 0xffL) << 48)
| ((buf[offset + 5] & 0xffL) << 40)
| ((buf[offset + 4] & 0xffL) << 32)
| ((buf[offset + 3] & 0xffL) << 24)
| ((buf[offset + 2] & 0xffL) << 16)
| ((buf[offset + 1] & 0xffL) << 8)
| ((buf[offset] & 0xffL)); // no shift needed
}
/** Returns the MurmurHash3_x64_128 hash, placing the result in "out". */
@SuppressWarnings("fallthrough") // the huge switch is designed to use fall through into cases!
static void murmurhash3_x64_128(byte[] key, int offset, int len, int seed, LongPair out) {
// The original algorithm does have a 32 bit unsigned seed.
// We have to mask to match the behavior of the unsigned types and prevent sign extension.
long h1 = seed & 0x00000000FFFFFFFFL;
long h2 = seed & 0x00000000FFFFFFFFL;
final long c1 = 0x87c37b91114253d5L;
final long c2 = 0x4cf5ad432745937fL;
int roundedEnd = offset + (len & 0xFFFFFFF0); // round down to 16 byte block
for (int i = offset; i < roundedEnd; i += 16) {
long k1 = getLongLittleEndian(key, i);
long k2 = getLongLittleEndian(key, i + 8);
k1 *= c1;
k1 = Long.rotateLeft(k1, 31);
k1 *= c2;
h1 ^= k1;
h1 = Long.rotateLeft(h1, 27);
h1 += h2;
h1 = h1 * 5 + 0x52dce729;
k2 *= c2;
k2 = Long.rotateLeft(k2, 33);
k2 *= c1;
h2 ^= k2;
h2 = Long.rotateLeft(h2, 31);
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
}
long k1 = 0;
long k2 = 0;
switch (len & 15) {
case 15:
k2 = (key[roundedEnd + 14] & 0xffL) << 48;
case 14:
k2 |= (key[roundedEnd + 13] & 0xffL) << 40;
case 13:
k2 |= (key[roundedEnd + 12] & 0xffL) << 32;
case 12:
k2 |= (key[roundedEnd + 11] & 0xffL) << 24;
case 11:
k2 |= (key[roundedEnd + 10] & 0xffL) << 16;
case 10:
k2 |= (key[roundedEnd + 9] & 0xffL) << 8;
case 9:
k2 |= (key[roundedEnd + 8] & 0xffL);
k2 *= c2;
k2 = Long.rotateLeft(k2, 33);
k2 *= c1;
h2 ^= k2;
case 8:
k1 = ((long) key[roundedEnd + 7]) << 56;
case 7:
k1 |= (key[roundedEnd + 6] & 0xffL) << 48;
case 6:
k1 |= (key[roundedEnd + 5] & 0xffL) << 40;
case 5:
k1 |= (key[roundedEnd + 4] & 0xffL) << 32;
case 4:
k1 |= (key[roundedEnd + 3] & 0xffL) << 24;
case 3:
k1 |= (key[roundedEnd + 2] & 0xffL) << 16;
case 2:
k1 |= (key[roundedEnd + 1] & 0xffL) << 8;
case 1:
k1 |= (key[roundedEnd] & 0xffL);
k1 *= c1;
k1 = Long.rotateLeft(k1, 31);
k1 *= c2;
h1 ^= k1;
}
// ----------
// finalization
h1 ^= len;
h2 ^= len;
h1 += h2;
h2 += h1;
h1 = fmix64(h1);
h2 = fmix64(h2);
h1 += h2;
h2 += h1;
out.val1 = h1;
out.val2 = h2;
}
private static long fmix64(long k) {
k ^= k >>> 33;
k *= 0xff51afd7ed558ccdL;
k ^= k >>> 33;
k *= 0xc4ceb9fe1a85ec53L;
k ^= k >>> 33;
return k;
}
}