blob: fcf72b75e1f741f78c316cf827524753a1555ce1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.suggest.fst;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.search.suggest.InputIterator;
import org.apache.lucene.search.suggest.Lookup;
import org.apache.lucene.search.suggest.fst.FSTCompletion.Completion;
import org.apache.lucene.search.suggest.tst.TSTLookup;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ByteArrayDataOutput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.OfflineSorter;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.NoOutputs;
/**
* An adapter from {@link Lookup} API to {@link FSTCompletion}.
*
* <p>This adapter differs from {@link FSTCompletion} in that it attempts to discretize any
* "weights" as passed from in {@link InputIterator#weight()} to match the number of buckets. For
* the rationale for bucketing, see {@link FSTCompletion}.
*
* <p><b>Note:</b>Discretization requires an additional sorting pass.
*
* <p>The range of weights for bucketing/ discretization is determined by sorting the input by
* weight and then dividing into equal ranges. Then, scores within each range are assigned to that
* bucket.
*
* <p>Note that this means that even large differences in weights may be lost during automaton
* construction, but the overall distinction between "classes" of weights will be preserved
* regardless of the distribution of weights.
*
* <p>For fine-grained control over which weights are assigned to which buckets, use {@link
* FSTCompletion} directly or {@link TSTLookup}, for example.
*
* @see FSTCompletion
* @lucene.experimental
*/
public class FSTCompletionLookup extends Lookup implements Accountable {
/**
* An invalid bucket count if we're creating an object of this class from an existing FST.
*
* @see #FSTCompletionLookup(Directory, String, FSTCompletion, boolean)
*/
private static int INVALID_BUCKETS_COUNT = -1;
/**
* Shared tail length for conflating in the created automaton. Setting this to larger values
* ({@link Integer#MAX_VALUE}) will create smaller (or minimal) automata at the cost of RAM for
* keeping nodes hash in the {@link FST}.
*
* <p>Empirical pick.
*/
private static final int sharedTailLength = 5;
private final Directory tempDir;
private final String tempFileNamePrefix;
private int buckets;
private boolean exactMatchFirst;
/** Automaton used for completions with higher weights reordering. */
private FSTCompletion higherWeightsCompletion;
/** Automaton used for normal completions. */
private FSTCompletion normalCompletion;
/** Number of entries the lookup was built with */
private long count = 0;
/** This constructor should only be used to read a previously saved suggester. */
public FSTCompletionLookup() {
this(null, null);
}
/**
* This constructor prepares for creating a suggested FST using the {@link #build(InputIterator)}
* method. The number of weight discretization buckets is set to {@link
* FSTCompletion#DEFAULT_BUCKETS} and exact matches are promoted to the top of the suggestions
* list.
*/
public FSTCompletionLookup(Directory tempDir, String tempFileNamePrefix) {
this(tempDir, tempFileNamePrefix, FSTCompletion.DEFAULT_BUCKETS, true);
}
/**
* This constructor prepares for creating a suggested FST using the {@link #build(InputIterator)}
* method.
*
* @param buckets The number of weight discretization buckets (see {@link FSTCompletion} for
* details).
* @param exactMatchFirst If <code>true</code> exact matches are promoted to the top of the
* suggestions list. Otherwise they appear in the order of discretized weight and alphabetical
* within the bucket.
*/
public FSTCompletionLookup(
Directory tempDir, String tempFileNamePrefix, int buckets, boolean exactMatchFirst) {
this.buckets = buckets;
this.exactMatchFirst = exactMatchFirst;
this.tempDir = tempDir;
this.tempFileNamePrefix = tempFileNamePrefix;
}
/**
* This constructor takes a pre-built automaton.
*
* @param completion An instance of {@link FSTCompletion}.
* @param exactMatchFirst If <code>true</code> exact matches are promoted to the top of the
* suggestions list. Otherwise they appear in the order of discretized weight and alphabetical
* within the bucket.
*/
public FSTCompletionLookup(
Directory tempDir,
String tempFileNamePrefix,
FSTCompletion completion,
boolean exactMatchFirst) {
this(tempDir, tempFileNamePrefix, INVALID_BUCKETS_COUNT, exactMatchFirst);
this.normalCompletion = new FSTCompletion(completion.getFST(), false, exactMatchFirst);
this.higherWeightsCompletion = new FSTCompletion(completion.getFST(), true, exactMatchFirst);
}
@Override
public void build(InputIterator iterator) throws IOException {
if (iterator.hasPayloads()) {
throw new IllegalArgumentException("this suggester doesn't support payloads");
}
if (iterator.hasContexts()) {
throw new IllegalArgumentException("this suggester doesn't support contexts");
}
OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix);
ExternalRefSorter externalSorter = new ExternalRefSorter(sorter);
IndexOutput tempInput =
tempDir.createTempOutput(tempFileNamePrefix, "input", IOContext.DEFAULT);
String tempSortedFileName = null;
OfflineSorter.ByteSequencesWriter writer = new OfflineSorter.ByteSequencesWriter(tempInput);
OfflineSorter.ByteSequencesReader reader = null;
// Push floats up front before sequences to sort them. For now, assume they are non-negative.
// If negative floats are allowed some trickery needs to be done to find their byte order.
count = 0;
try {
byte[] buffer = new byte[0];
ByteArrayDataOutput output = new ByteArrayDataOutput(buffer);
BytesRef spare;
int inputLineCount = 0;
while ((spare = iterator.next()) != null) {
if (spare.length + 4 >= buffer.length) {
buffer = ArrayUtil.grow(buffer, spare.length + 4);
}
output.reset(buffer);
output.writeInt(encodeWeight(iterator.weight()));
output.writeBytes(spare.bytes, spare.offset, spare.length);
writer.write(buffer, 0, output.getPosition());
inputLineCount++;
}
CodecUtil.writeFooter(tempInput);
writer.close();
// We don't know the distribution of scores and we need to bucket them, so we'll sort
// and divide into equal buckets.
tempSortedFileName = sorter.sort(tempInput.getName());
tempDir.deleteFile(tempInput.getName());
FSTCompletionBuilder builder =
new FSTCompletionBuilder(buckets, externalSorter, sharedTailLength);
reader =
new OfflineSorter.ByteSequencesReader(
tempDir.openChecksumInput(tempSortedFileName, IOContext.READONCE),
tempSortedFileName);
long line = 0;
int previousBucket = 0;
int previousScore = 0;
ByteArrayDataInput input = new ByteArrayDataInput();
BytesRef tmp2 = new BytesRef();
while (true) {
BytesRef scratch = reader.next();
if (scratch == null) {
break;
}
input.reset(scratch.bytes, scratch.offset, scratch.length);
int currentScore = input.readInt();
int bucket;
if (line > 0 && currentScore == previousScore) {
bucket = previousBucket;
} else {
bucket = (int) (line * buckets / inputLineCount);
}
previousScore = currentScore;
previousBucket = bucket;
// Only append the input, discard the weight.
tmp2.bytes = scratch.bytes;
tmp2.offset = scratch.offset + input.getPosition();
tmp2.length = scratch.length - input.getPosition();
builder.add(tmp2, bucket);
line++;
count++;
}
// The two FSTCompletions share the same automaton.
this.higherWeightsCompletion = builder.build();
this.normalCompletion =
new FSTCompletion(higherWeightsCompletion.getFST(), false, exactMatchFirst);
} finally {
IOUtils.closeWhileHandlingException(reader, writer, externalSorter);
IOUtils.deleteFilesIgnoringExceptions(tempDir, tempInput.getName(), tempSortedFileName);
}
}
/** weight -&gt; cost */
private static int encodeWeight(long value) {
if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
throw new UnsupportedOperationException("cannot encode value: " + value);
}
return (int) value;
}
@Override
public List<LookupResult> lookup(
CharSequence key, Set<BytesRef> contexts, boolean higherWeightsFirst, int num) {
if (contexts != null) {
throw new IllegalArgumentException("this suggester doesn't support contexts");
}
final List<Completion> completions;
if (higherWeightsFirst) {
completions = higherWeightsCompletion.lookup(key, num);
} else {
completions = normalCompletion.lookup(key, num);
}
final ArrayList<LookupResult> results = new ArrayList<>(completions.size());
CharsRefBuilder spare = new CharsRefBuilder();
for (Completion c : completions) {
spare.copyUTF8Bytes(c.utf8);
results.add(new LookupResult(spare.toString(), c.bucket));
}
return results;
}
/**
* Returns the bucket (weight) as a Long for the provided key if it exists, otherwise null if it
* does not.
*/
public Object get(CharSequence key) {
final int bucket = normalCompletion.getBucket(key);
return bucket == -1 ? null : Long.valueOf(bucket);
}
@Override
public synchronized boolean store(DataOutput output) throws IOException {
output.writeVLong(count);
if (normalCompletion == null || normalCompletion.getFST() == null) {
return false;
}
normalCompletion.getFST().save(output, output);
return true;
}
@Override
public synchronized boolean load(DataInput input) throws IOException {
count = input.readVLong();
this.higherWeightsCompletion =
new FSTCompletion(new FST<>(input, input, NoOutputs.getSingleton()));
this.normalCompletion =
new FSTCompletion(higherWeightsCompletion.getFST(), false, exactMatchFirst);
return true;
}
@Override
public long ramBytesUsed() {
long mem =
RamUsageEstimator.shallowSizeOf(this)
+ RamUsageEstimator.shallowSizeOf(normalCompletion)
+ RamUsageEstimator.shallowSizeOf(higherWeightsCompletion);
if (normalCompletion != null) {
mem += normalCompletion.getFST().ramBytesUsed();
}
if (higherWeightsCompletion != null
&& (normalCompletion == null
|| normalCompletion.getFST() != higherWeightsCompletion.getFST())) {
// the fst should be shared between the 2 completion instances, don't count it twice
mem += higherWeightsCompletion.getFST().ramBytesUsed();
}
return mem;
}
@Override
public Collection<Accountable> getChildResources() {
List<Accountable> resources = new ArrayList<>();
if (normalCompletion != null) {
resources.add(Accountables.namedAccountable("fst", normalCompletion.getFST()));
}
if (higherWeightsCompletion != null
&& (normalCompletion == null
|| normalCompletion.getFST() != higherWeightsCompletion.getFST())) {
resources.add(
Accountables.namedAccountable("higher weights fst", higherWeightsCompletion.getFST()));
}
return Collections.unmodifiableList(resources);
}
@Override
public long getCount() {
return count;
}
}