blob: 83baabcd14b9c783daa2843ec707ef0d0d132728 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.suggest.document;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeSet;
import org.apache.lucene.analysis.miscellaneous.ConcatenateGraphFilter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.fst.Util;
/**
* A {@link CompletionQuery} that matches documents specified by
* a wrapped {@link CompletionQuery} supporting boosting and/or filtering
* by specified contexts.
* <p>
* Use this query against {@link ContextSuggestField}
* <p>
* Example of using a {@link CompletionQuery} with boosted
* contexts:
* <pre class="prettyprint">
* CompletionQuery completionQuery = ...;
* ContextQuery query = new ContextQuery(completionQuery);
* query.addContext("context1", 2);
* query.addContext("context2", 1);
* </pre>
* <p>
* NOTE:
* <ul>
* <li>
* This query can be constructed with
* {@link PrefixCompletionQuery}, {@link RegexCompletionQuery}
* or {@link FuzzyCompletionQuery} query.
* </li>
* <li>
* To suggest across all contexts, use {@link #addAllContexts()}.
* When no context is added, the default behaviour is to suggest across
* all contexts.
* </li>
* <li>
* To apply the same boost to multiple contexts sharing the same prefix,
* use {@link #addContext(CharSequence, float, boolean)} with the common
* context prefix, boost and set <code>exact</code> to false.
* </li>
* <li>
* Using this query against a {@link SuggestField} (not context enabled),
* would yield results ignoring any context filtering/boosting
* </li>
* </ul>
*
* @lucene.experimental
*/
public class ContextQuery extends CompletionQuery implements Accountable {
// Shallow size of a ContextQuery instance; the fixed term of the estimate built in updateRamBytesUsed().
private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(ContextQuery.class);
// Reusable builder handed to Util.toIntsRef when addContext converts a context value.
private IntsRefBuilder scratch = new IntsRefBuilder();
// Registered contexts: the key is the context value converted to an IntsRef (a deep copy),
// the value carries the boost and the exact/prefix flag given to addContext.
private Map<IntsRef, ContextMetaData> contexts;
// Set to true by addAllContexts(); createWeight() forwards it to toContextAutomaton().
private boolean matchAllContexts = false;
/** Inner completion query */
protected CompletionQuery innerQuery;
// Cached estimate returned by ramBytesUsed(); refreshed by updateRamBytesUsed().
private long ramBytesUsed;
/**
* Constructs a context completion query that matches
* documents specified by <code>query</code>.
* <p>
* Use {@link #addContext(CharSequence, float, boolean)}
* to add context(s) with boost
*/
public ContextQuery(CompletionQuery query) {
super(query.getTerm(), query.getFilter());
if (query instanceof ContextQuery) {
throw new IllegalArgumentException("'query' parameter must not be of type "
+ this.getClass().getSimpleName());
}
this.innerQuery = query;
contexts = new HashMap<>();
updateRamBytesUsed();
}
/**
 * Recomputes the cached RAM estimate as the sum of the shallow instance size,
 * the estimated size of the contexts map and the estimated size of the inner query
 * (with {@link RamUsageEstimator#QUERY_DEFAULT_RAM_BYTES_USED} passed as the default size).
 */
private void updateRamBytesUsed() {
  final long contextsBytes = RamUsageEstimator.sizeOfObject(contexts);
  final long innerQueryBytes =
      RamUsageEstimator.sizeOfObject(innerQuery, RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED);
  ramBytesUsed = BASE_RAM_BYTES + contextsBytes + innerQueryBytes;
}
/**
 * Registers <code>context</code> as an exact context using the default boost of 1.
 * Shorthand for {@link #addContext(CharSequence, float, boolean)} with
 * <code>boost = 1</code> and <code>exact = true</code>.
 */
public void addContext(CharSequence context) {
  this.addContext(context, 1.0f, true);
}
/**
 * Registers <code>context</code> as an exact context with the given boost.
 * Shorthand for {@link #addContext(CharSequence, float, boolean)} with
 * <code>exact = true</code>.
 */
public void addContext(CharSequence context, float boost) {
  this.addContext(context, boost, true);
}
/**
 * Registers a context together with its boost. Pass <code>exact = false</code>
 * when <code>context</code> is meant as a prefix of indexed contexts rather than
 * a complete context value.
 * <p>
 * Registering the same context value again replaces the previous entry.
 *
 * @param context the context value (or context prefix)
 * @param boost boost for the context; must not be negative
 * @param exact whether <code>context</code> is a full context value
 * @throws IllegalArgumentException if <code>boost</code> is negative or
 *         <code>context</code> contains the reserved context separator
 */
public void addContext(CharSequence context, float boost, boolean exact) {
  if (boost < 0f) {
    throw new IllegalArgumentException("'boost' must be >= 0");
  }
  // The separator delimits context and suggestion, so it may not occur inside a context.
  int position = 0;
  while (position < context.length()) {
    final char current = context.charAt(position);
    if (current == ContextSuggestField.CONTEXT_SEPARATOR) {
      throw new IllegalArgumentException("Illegal value [" + context + "] UTF-16 codepoint [0x"
          + Integer.toHexString((int) current) + "] at position " + position + " is a reserved character");
    }
    position++;
  }
  // The shared scratch builder is reused by later calls, hence the deep copy for the map key.
  final IntsRef key = IntsRef.deepCopyOf(Util.toIntsRef(new BytesRef(context), scratch));
  contexts.put(key, new ContextMetaData(boost, exact));
  updateRamBytesUsed();
}
/**
 * Makes this query suggest across all contexts.
 * <p>
 * Only raises the match-all flag that {@link #createWeight} hands to the
 * context automaton construction; contexts registered through
 * <code>addContext</code> are kept as they are.
 */
public void addAllContexts() {
  this.matchAllContexts = true;
}
/**
 * Renders the registered contexts followed by the inner query.
 * <p>
 * With at least one context the result has the form
 * <code>contexts:[c1^boost,c2*^boost],&lt;inner query&gt;</code>, where <code>*</code>
 * marks a non-exact (prefix) context and the <code>^boost</code> part is left out
 * for a boost of 0. Without contexts only the inner query is rendered.
 */
@Override
public String toString(String field) {
  if (contexts.isEmpty()) {
    return innerQuery.toString(field);
  }
  final StringBuilder out = new StringBuilder("contexts:[");
  final BytesRefBuilder bytes = new BytesRefBuilder();
  boolean first = true;
  for (Map.Entry<IntsRef, ContextMetaData> entry : contexts.entrySet()) {
    if (!first) {
      out.append(",");
    }
    first = false;
    out.append(Util.toBytesRef(entry.getKey(), bytes).utf8ToString());
    final ContextMetaData metaData = entry.getValue();
    if (!metaData.exact) {
      out.append("*");
    }
    if (metaData.boost != 0) {
      out.append("^").append(Float.toString(metaData.boost));
    }
  }
  out.append("],");
  return out.toString() + innerQuery.toString(field);
}
/**
 * Builds the weight for this query from the weight of the inner query.
 * <p>
 * The resulting automaton is: context automaton, then an optional
 * {@link ConcatenateGraphFilter#SEP_LABEL}, then the inner query's automaton,
 * determinized with {@link Operations#DEFAULT_DETERMINIZE_WORK_LIMIT}.
 * The returned weight also receives the boost of every registered context and
 * the distinct context lengths sorted from longest to shortest.
 * <p>
 * If the inner automaton is empty, a plain {@link CompletionWeight} over an
 * empty automaton is returned instead.
 */
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  final CompletionWeight innerWeight = (CompletionWeight) innerQuery.createWeight(searcher, scoreMode, boost);
  final Automaton innerAutomaton = innerWeight.getAutomaton();

  // Nothing can match: return an empty weight so that scoring does not
  // have to traverse all contexts.
  if (Operations.isEmpty(innerAutomaton)) {
    return new CompletionWeight(this, Automata.makeEmpty());
  }

  // When separators are preserved the fst contains a SEP_LABEL behind each gap,
  // so the query automaton has to accept (optionally) that label as well.
  final Automaton sepLabel = Automata.makeChar(ConcatenateGraphFilter.SEP_LABEL);
  final Automaton suffixAutomaton = Operations.concatenate(Operations.optional(sepLabel), innerAutomaton);
  final Automaton contextPart = toContextAutomaton(contexts, matchAllContexts);
  final Automaton fullAutomaton = Operations.determinize(
      Operations.concatenate(contextPart, suffixAutomaton), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);

  // Collect boost per context and the set of distinct context lengths.
  final Map<IntsRef, Float> boostByContext = new HashMap<>(contexts.size());
  final TreeSet<Integer> distinctLengths = new TreeSet<>();
  for (Map.Entry<IntsRef, ContextMetaData> entry : contexts.entrySet()) {
    final IntsRef key = entry.getKey();
    boostByContext.put(key, entry.getValue().boost);
    distinctLengths.add(key.length);
  }

  // Longest length first, which is the order the weight probes them in.
  final int[] lengthsLongestFirst = new int[distinctLengths.size()];
  int slot = 0;
  for (int length : distinctLengths.descendingSet()) {
    lengthsLongestFirst[slot++] = length;
  }
  return new ContextCompletionWeight(this, fullAutomaton, innerWeight, boostByContext, lengthsLongestFirst);
}
/**
 * Builds the automaton for the context part of a suggestion, including the
 * trailing {@link ContextSuggestField#CONTEXT_SEPARATOR}.
 * <p>
 * When <code>matchAllContexts</code> is set or no context is registered, the
 * result is "any string, then the separator". Otherwise it is the union, over
 * all registered contexts, of "context value, then (for non-exact contexts) any
 * string, then the separator".
 */
private static Automaton toContextAutomaton(final Map<IntsRef, ContextMetaData> contexts, final boolean matchAllContexts) {
  final Automaton anyString = Operations.repeat(Automata.makeAnyString());
  final Automaton separator = Automata.makeChar(ContextSuggestField.CONTEXT_SEPARATOR);
  if (matchAllContexts || contexts.isEmpty()) {
    return Operations.concatenate(anyString, separator);
  }

  Automaton union = null;
  for (Map.Entry<IntsRef, ContextMetaData> entry : contexts.entrySet()) {
    final IntsRef context = entry.getKey();
    Automaton single = Automata.makeString(context.ints, context.offset, context.length);
    if (!entry.getValue().exact) {
      // prefix context: allow anything between the prefix and the separator
      single = Operations.concatenate(single, anyString);
    }
    single = Operations.concatenate(single, separator);
    union = (union == null) ? single : Operations.union(union, single);
  }
  return union;
}
/**
 * Immutable pair of settings recorded for a single context value.
 */
private static class ContextMetaData {

  /** Boost associated with the context value. */
  private final float boost;

  /**
   * <code>true</code> if the context value is to be treated as an exact value,
   * <code>false</code> if it is to be treated as a context prefix.
   */
  private final boolean exact;

  private ContextMetaData(float boost, boolean exact) {
    this.exact = exact;
    this.boost = boost;
  }
}
/**
 * Weight that splits every matched path into its context part and its
 * suggestion part. The context part determines {@link #context()} and the
 * context boost; the suggestion part is forwarded to the inner weight.
 */
private static class ContextCompletionWeight extends CompletionWeight {
  /** Boost per registered context. */
  private final Map<IntsRef, Float> contextMap;
  /** Distinct lengths of the registered contexts, in the order they are probed. */
  private final int[] contextLengths;
  /** Weight of the wrapped query; receives the suggestion part of each match. */
  private final CompletionWeight innerWeight;
  /** Reusable buffer for converting the context ints to a string. */
  private final BytesRefBuilder scratch = new BytesRefBuilder();
  /** Context boost of the most recent match, 0 for an unregistered context. */
  private float currentBoost;
  /** Context of the most recent match, <code>null</code> if it was empty. */
  private CharSequence currentContext;

  public ContextCompletionWeight(CompletionQuery query, Automaton automaton, CompletionWeight innerWeight,
                                 Map<IntsRef, Float> contextMap,
                                 int[] contextLengths) throws IOException {
    super(query, automaton);
    this.contextMap = contextMap;
    this.contextLengths = contextLengths;
    this.innerWeight = innerWeight;
  }

  /**
   * Looks up the boost for the context at the start of <code>pathPrefix</code>
   * by probing each registered context length in array order, then hands the
   * suggestion part to the inner weight. A path whose start matches no
   * registered context gets a context boost of 0.
   */
  @Override
  protected void setNextMatch(final IntsRef pathPrefix) {
    // Work on a copy so that pathPrefix's own length is never modified.
    final IntsRef ref = pathPrefix.clone();
    for (int contextLength : contextLengths) {
      if (contextLength > pathPrefix.length) {
        continue;
      }
      // Temporarily shorten the view so it can serve as the map key.
      ref.length = contextLength;
      final Float contextBoost = contextMap.get(ref);
      if (contextBoost != null) {
        currentBoost = contextBoost;
        ref.length = pathPrefix.length;
        setInnerWeight(ref, contextLength);
        return;
      }
    }
    // unknown context
    ref.length = pathPrefix.length;
    currentBoost = 0f;
    setInnerWeight(ref, 0);
  }

  /**
   * Finds the first context separator at or after relative position
   * <code>offset</code>, records the context in front of it and passes
   * everything behind it (minus one optional leading SEP_LABEL) to the inner
   * weight. Does nothing if no separator is found.
   * <p>
   * All positions are relative to <code>ref.offset</code>, so the method is
   * correct for any offset; <code>ref</code> itself is not modified.
   */
  private void setInnerWeight(IntsRef ref, int offset) {
    final int[] ints = ref.ints;
    final int base = ref.offset;
    final int length = ref.length;
    final IntsRefBuilder refBuilder = new IntsRefBuilder();
    for (int i = offset; i < length; i++) {
      if (ints[base + i] != ContextSuggestField.CONTEXT_SEPARATOR) {
        continue;
      }
      if (i > 0) {
        // everything in front of the separator is the context
        refBuilder.copyInts(ints, base, i);
        currentContext = Util.toBytesRef(refBuilder.get(), scratch).utf8ToString();
      } else {
        currentContext = null;
      }
      int suffixStart = i + 1;
      assert suffixStart < length : "input should not end with the context separator";
      if (ints[base + suffixStart] == ConcatenateGraphFilter.SEP_LABEL) {
        // skip the SEP_LABEL that may directly follow the context separator
        suffixStart++;
        assert suffixStart < length : "input should not end with a context separator followed by SEP_LABEL";
      }
      // copyInts overwrites the builder, so it can be reused for the suggestion part
      refBuilder.copyInts(ints, base + suffixStart, length - suffixStart);
      innerWeight.setNextMatch(refBuilder.get());
      return;
    }
  }

  /** Returns the context of the most recent match, or <code>null</code> if it was empty. */
  @Override
  protected CharSequence context() {
    return currentContext;
  }

  /** Returns the context boost of the most recent match plus the inner weight's boost. */
  @Override
  protected float boost() {
    return currentBoost + innerWeight.boost();
  }
}
/**
 * Not supported by this query.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public boolean equals(Object o) {
throw new UnsupportedOperationException();
}
/**
 * Not supported by this query.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public int hashCode() {
throw new UnsupportedOperationException();
}
/**
 * Reports this query to <code>visitor</code> as a leaf; the inner query is
 * not visited separately.
 */
@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}
/**
 * Returns the cached RAM estimate. The value is computed in the constructor
 * and refreshed each time a context is added.
 */
@Override
public long ramBytesUsed() {
  return this.ramBytesUsed;
}
}