| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search.suggest.fst; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.List; |
| import java.util.Set; |
| |
| import org.apache.lucene.search.suggest.InputIterator; |
| import org.apache.lucene.search.suggest.Lookup; |
| import org.apache.lucene.search.suggest.SortedInputIterator; |
| import org.apache.lucene.store.ByteArrayDataInput; |
| import org.apache.lucene.store.ByteArrayDataOutput; |
| import org.apache.lucene.store.DataInput; |
| import org.apache.lucene.store.DataOutput; |
| import org.apache.lucene.store.Directory; |
| import org.apache.lucene.util.Accountable; |
| import org.apache.lucene.util.Accountables; |
| import org.apache.lucene.util.ArrayUtil; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.BytesRefBuilder; |
| import org.apache.lucene.util.CharsRefBuilder; |
| import org.apache.lucene.util.IntsRefBuilder; |
| import org.apache.lucene.util.OfflineSorter.ByteSequencesWriter; |
| import org.apache.lucene.util.fst.Builder; |
| import org.apache.lucene.util.fst.FST; |
| import org.apache.lucene.util.fst.FST.Arc; |
| import org.apache.lucene.util.fst.FST.BytesReader; |
| import org.apache.lucene.util.fst.PositiveIntOutputs; |
| import org.apache.lucene.util.fst.Util; |
| import org.apache.lucene.util.fst.Util.Result; |
| import org.apache.lucene.util.fst.Util.TopResults; |
| |
| /** |
| * Suggester based on a weighted FST: it first traverses the prefix, |
| * then walks the <i>n</i> shortest paths to retrieve top-ranked |
| * suggestions. |
| * <p> |
| * <b>NOTE</b>: |
| * Input weights must be between 0 and {@link Integer#MAX_VALUE}, any |
| * other values will be rejected. |
| * |
| * @lucene.experimental |
| */ |
| // redundant 'implements Accountable' to workaround javadocs bugs |
| public class WFSTCompletionLookup extends Lookup implements Accountable { |
| |
| /** |
| * FST<Long>, weights are encoded as costs: (Integer.MAX_VALUE-weight) |
| */ |
| // NOTE: like FSTSuggester, this is really a WFSA, if you want to |
| // customize the code to add some output you should use PairOutputs. |
| private FST<Long> fst = null; |
| |
| /** |
| * True if exact match suggestions should always be returned first. |
| */ |
| private final boolean exactFirst; |
| |
| /** Number of entries the lookup was built with */ |
| private long count = 0; |
| |
| private final Directory tempDir; |
| private final String tempFileNamePrefix; |
| |
| /** |
| * Calls {@link #WFSTCompletionLookup(Directory,String,boolean) WFSTCompletionLookup(null,null,true)} |
| */ |
| public WFSTCompletionLookup(Directory tempDir, String tempFileNamePrefix) { |
| this(tempDir, tempFileNamePrefix, true); |
| } |
| |
| /** |
| * Creates a new suggester. |
| * |
| * @param exactFirst <code>true</code> if suggestions that match the |
| * prefix exactly should always be returned first, regardless |
| * of score. This has no performance impact, but could result |
| * in low-quality suggestions. |
| */ |
| public WFSTCompletionLookup(Directory tempDir, String tempFileNamePrefix, boolean exactFirst) { |
| this.exactFirst = exactFirst; |
| this.tempDir = tempDir; |
| this.tempFileNamePrefix = tempFileNamePrefix; |
| } |
| |
| @Override |
| public void build(InputIterator iterator) throws IOException { |
| if (iterator.hasPayloads()) { |
| throw new IllegalArgumentException("this suggester doesn't support payloads"); |
| } |
| if (iterator.hasContexts()) { |
| throw new IllegalArgumentException("this suggester doesn't support contexts"); |
| } |
| count = 0; |
| BytesRef scratch = new BytesRef(); |
| InputIterator iter = new WFSTInputIterator(tempDir, tempFileNamePrefix, iterator); |
| IntsRefBuilder scratchInts = new IntsRefBuilder(); |
| BytesRefBuilder previous = null; |
| PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(); |
| Builder<Long> builder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs); |
| while ((scratch = iter.next()) != null) { |
| long cost = iter.weight(); |
| |
| if (previous == null) { |
| previous = new BytesRefBuilder(); |
| } else if (scratch.equals(previous.get())) { |
| continue; // for duplicate suggestions, the best weight is actually |
| // added |
| } |
| Util.toIntsRef(scratch, scratchInts); |
| builder.add(scratchInts.get(), cost); |
| previous.copyBytes(scratch); |
| count++; |
| } |
| fst = builder.finish(); |
| } |
| |
| |
| @Override |
| public boolean store(DataOutput output) throws IOException { |
| output.writeVLong(count); |
| if (fst == null) { |
| return false; |
| } |
| fst.save(output, output); |
| return true; |
| } |
| |
| @Override |
| public boolean load(DataInput input) throws IOException { |
| count = input.readVLong(); |
| this.fst = new FST<>(input, input, PositiveIntOutputs.getSingleton()); |
| return true; |
| } |
| |
| @Override |
| public List<LookupResult> lookup(CharSequence key, Set<BytesRef> contexts, boolean onlyMorePopular, int num) { |
| if (contexts != null) { |
| throw new IllegalArgumentException("this suggester doesn't support contexts"); |
| } |
| assert num > 0; |
| |
| if (onlyMorePopular) { |
| throw new IllegalArgumentException("this suggester only works with onlyMorePopular=false"); |
| } |
| |
| if (fst == null) { |
| return Collections.emptyList(); |
| } |
| |
| BytesRefBuilder scratch = new BytesRefBuilder(); |
| scratch.copyChars(key); |
| int prefixLength = scratch.length(); |
| Arc<Long> arc = new Arc<>(); |
| |
| // match the prefix portion exactly |
| Long prefixOutput = null; |
| try { |
| prefixOutput = lookupPrefix(scratch.get(), arc); |
| } catch (IOException bogus) { throw new RuntimeException(bogus); } |
| |
| if (prefixOutput == null) { |
| return Collections.emptyList(); |
| } |
| |
| List<LookupResult> results = new ArrayList<>(num); |
| CharsRefBuilder spare = new CharsRefBuilder(); |
| if (exactFirst && arc.isFinal()) { |
| spare.copyUTF8Bytes(scratch.get()); |
| results.add(new LookupResult(spare.toString(), decodeWeight(prefixOutput + arc.nextFinalOutput()))); |
| if (--num == 0) { |
| return results; // that was quick |
| } |
| } |
| |
| // complete top-N |
| TopResults<Long> completions = null; |
| try { |
| completions = Util.shortestPaths(fst, arc, prefixOutput, weightComparator, num, !exactFirst); |
| assert completions.isComplete; |
| } catch (IOException bogus) { |
| throw new RuntimeException(bogus); |
| } |
| |
| BytesRefBuilder suffix = new BytesRefBuilder(); |
| for (Result<Long> completion : completions) { |
| scratch.setLength(prefixLength); |
| // append suffix |
| Util.toBytesRef(completion.input, suffix); |
| scratch.append(suffix); |
| spare.copyUTF8Bytes(scratch.get()); |
| results.add(new LookupResult(spare.toString(), decodeWeight(completion.output))); |
| } |
| return results; |
| } |
| |
| private Long lookupPrefix(BytesRef scratch, Arc<Long> arc) throws /*Bogus*/IOException { |
| assert 0 == fst.outputs.getNoOutput().longValue(); |
| long output = 0; |
| BytesReader bytesReader = fst.getBytesReader(); |
| |
| fst.getFirstArc(arc); |
| |
| byte[] bytes = scratch.bytes; |
| int pos = scratch.offset; |
| int end = pos + scratch.length; |
| while (pos < end) { |
| if (fst.findTargetArc(bytes[pos++] & 0xff, arc, arc, bytesReader) == null) { |
| return null; |
| } else { |
| output += arc.output().longValue(); |
| } |
| } |
| |
| return output; |
| } |
| |
| /** |
| * Returns the weight associated with an input string, |
| * or null if it does not exist. |
| */ |
| public Object get(CharSequence key) { |
| if (fst == null) { |
| return null; |
| } |
| Arc<Long> arc = new Arc<>(); |
| Long result = null; |
| try { |
| result = lookupPrefix(new BytesRef(key), arc); |
| } catch (IOException bogus) { throw new RuntimeException(bogus); } |
| if (result == null || !arc.isFinal()) { |
| return null; |
| } else { |
| return Integer.valueOf(decodeWeight(result + arc.nextFinalOutput())); |
| } |
| } |
| |
| /** cost -> weight */ |
| private static int decodeWeight(long encoded) { |
| return (int)(Integer.MAX_VALUE - encoded); |
| } |
| |
| /** weight -> cost */ |
| private static int encodeWeight(long value) { |
| if (value < 0 || value > Integer.MAX_VALUE) { |
| throw new UnsupportedOperationException("cannot encode value: " + value); |
| } |
| return Integer.MAX_VALUE - (int)value; |
| } |
| |
| private static final class WFSTInputIterator extends SortedInputIterator { |
| |
| WFSTInputIterator(Directory tempDir, String tempFileNamePrefix, InputIterator source) throws IOException { |
| super(tempDir, tempFileNamePrefix, source); |
| assert source.hasPayloads() == false; |
| } |
| |
| @Override |
| protected void encode(ByteSequencesWriter writer, ByteArrayDataOutput output, byte[] buffer, BytesRef spare, BytesRef payload, Set<BytesRef> contexts, long weight) throws IOException { |
| if (spare.length + 4 >= buffer.length) { |
| buffer = ArrayUtil.grow(buffer, spare.length + 4); |
| } |
| output.reset(buffer); |
| output.writeBytes(spare.bytes, spare.offset, spare.length); |
| output.writeInt(encodeWeight(weight)); |
| writer.write(buffer, 0, output.getPosition()); |
| } |
| |
| @Override |
| protected long decode(BytesRef scratch, ByteArrayDataInput tmpInput) { |
| scratch.length -= 4; // int |
| // skip suggestion: |
| tmpInput.reset(scratch.bytes, scratch.offset+scratch.length, 4); |
| return tmpInput.readInt(); |
| } |
| } |
| |
| static final Comparator<Long> weightComparator = new Comparator<Long> () { |
| @Override |
| public int compare(Long left, Long right) { |
| return left.compareTo(right); |
| } |
| }; |
| |
| /** Returns byte size of the underlying FST. */ |
| @Override |
| public long ramBytesUsed() { |
| return (fst == null) ? 0 : fst.ramBytesUsed(); |
| } |
| |
| @Override |
| public Collection<Accountable> getChildResources() { |
| if (fst == null) { |
| return Collections.emptyList(); |
| } else { |
| return Collections.singleton(Accountables.namedAccountable("fst", fst)); |
| } |
| } |
| |
| @Override |
| public long getCount() { |
| return count; |
| } |
| } |