blob: 00ba5a13fd73e1941fbdb1142b0b58725bf68e2e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.InPlaceMergeSorter;
/**
* A {@link Query} that blends index statistics across multiple terms.
* This is particularly useful when several terms should produce identical
* scores, regardless of their index statistics.
* <p>For instance imagine that you are resolving synonyms at search time,
* all terms should produce identical scores instead of the default behavior,
* which tends to give higher scores to rare terms.
* <p>An other useful use-case is cross-field search: imagine that you would
* like to search for {@code john} on two fields: {@code first_name} and
* {@code last_name}. You might not want to give a higher weight to matches
* on the field where {@code john} is rarer, in which case
* {@link BlendedTermQuery} would help as well.
* @lucene.experimental
*/
public final class BlendedTermQuery extends Query {
/** A Builder for {@link BlendedTermQuery}. */
public static class Builder {
private int numTerms = 0;
private Term[] terms = new Term[0];
private float[] boosts = new float[0];
private TermStates[] contexts = new TermStates[0];
private RewriteMethod rewriteMethod = DISJUNCTION_MAX_REWRITE;
/** Sole constructor. */
public Builder() {}
/** Set the {@link RewriteMethod}. Default is to use
* {@link BlendedTermQuery#DISJUNCTION_MAX_REWRITE}.
* @see RewriteMethod */
public Builder setRewriteMethod(RewriteMethod rewiteMethod) {
this.rewriteMethod = rewiteMethod;
return this;
}
/** Add a new {@link Term} to this builder, with a default boost of {@code 1}.
* @see #add(Term, float) */
public Builder add(Term term) {
return add(term, 1f);
}
/** Add a {@link Term} with the provided boost. The higher the boost, the
* more this term will contribute to the overall score of the
* {@link BlendedTermQuery}. */
public Builder add(Term term, float boost) {
return add(term, boost, null);
}
/**
* Expert: Add a {@link Term} with the provided boost and context.
* This method is useful if you already have a {@link TermStates}
* object constructed for the given term.
*/
public Builder add(Term term, float boost, TermStates context) {
if (numTerms >= BooleanQuery.getMaxClauseCount()) {
throw new BooleanQuery.TooManyClauses();
}
terms = ArrayUtil.grow(terms, numTerms + 1);
boosts = ArrayUtil.grow(boosts, numTerms + 1);
contexts = ArrayUtil.grow(contexts, numTerms + 1);
terms[numTerms] = term;
boosts[numTerms] = boost;
contexts[numTerms] = context;
numTerms += 1;
return this;
}
/** Build the {@link BlendedTermQuery}. */
public BlendedTermQuery build() {
return new BlendedTermQuery(
ArrayUtil.copyOfSubArray(terms, 0, numTerms),
ArrayUtil.copyOfSubArray(boosts, 0, numTerms),
ArrayUtil.copyOfSubArray(contexts, 0, numTerms),
rewriteMethod);
}
}
/** A {@link RewriteMethod} defines how queries for individual terms should
* be merged.
* @lucene.experimental
* @see BlendedTermQuery#BOOLEAN_REWRITE
* @see BlendedTermQuery.DisjunctionMaxRewrite */
public static abstract class RewriteMethod {
/** Sole constructor */
protected RewriteMethod() {}
/** Merge the provided sub queries into a single {@link Query} object. */
public abstract Query rewrite(Query[] subQueries);
}
/**
* A {@link RewriteMethod} that adds all sub queries to a {@link BooleanQuery}.
* This {@link RewriteMethod} is useful when matching on several fields is
* considered better than having a good match on a single field.
*/
public static final RewriteMethod BOOLEAN_REWRITE = new RewriteMethod() {
@Override
public Query rewrite(Query[] subQueries) {
BooleanQuery.Builder merged = new BooleanQuery.Builder();
for (Query query : subQueries) {
merged.add(query, Occur.SHOULD);
}
return merged.build();
}
};
/**
* A {@link RewriteMethod} that creates a {@link DisjunctionMaxQuery} out
* of the sub queries. This {@link RewriteMethod} is useful when having a
* good match on a single field is considered better than having average
* matches on several fields.
*/
public static class DisjunctionMaxRewrite extends RewriteMethod {
private final float tieBreakerMultiplier;
/** This {@link RewriteMethod} will create {@link DisjunctionMaxQuery}
* instances that have the provided tie breaker.
* @see DisjunctionMaxQuery */
public DisjunctionMaxRewrite(float tieBreakerMultiplier) {
this.tieBreakerMultiplier = tieBreakerMultiplier;
}
@Override
public Query rewrite(Query[] subQueries) {
return new DisjunctionMaxQuery(Arrays.asList(subQueries), tieBreakerMultiplier);
}
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
DisjunctionMaxRewrite that = (DisjunctionMaxRewrite) obj;
return tieBreakerMultiplier == that.tieBreakerMultiplier;
}
@Override
public int hashCode() {
return 31 * getClass().hashCode() + Float.floatToIntBits(tieBreakerMultiplier);
}
}
/** {@link DisjunctionMaxRewrite} instance with a tie-breaker of {@code 0.01}. */
public static final RewriteMethod DISJUNCTION_MAX_REWRITE = new DisjunctionMaxRewrite(0.01f);
private final Term[] terms;
private final float[] boosts;
private final TermStates[] contexts;
private final RewriteMethod rewriteMethod;
private BlendedTermQuery(Term[] terms, float[] boosts, TermStates[] contexts,
RewriteMethod rewriteMethod) {
assert terms.length == boosts.length;
assert terms.length == contexts.length;
this.terms = terms;
this.boosts = boosts;
this.contexts = contexts;
this.rewriteMethod = rewriteMethod;
// we sort terms so that equals/hashcode does not rely on the order
new InPlaceMergeSorter() {
@Override
protected void swap(int i, int j) {
Term tmpTerm = terms[i];
terms[i] = terms[j];
terms[j] = tmpTerm;
TermStates tmpContext = contexts[i];
contexts[i] = contexts[j];
contexts[j] = tmpContext;
float tmpBoost = boosts[i];
boosts[i] = boosts[j];
boosts[j] = tmpBoost;
}
@Override
protected int compare(int i, int j) {
return terms[i].compareTo(terms[j]);
}
}.sort(0, terms.length);
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) &&
equalsTo(getClass().cast(other));
}
private boolean equalsTo(BlendedTermQuery other) {
return Arrays.equals(terms, other.terms) &&
Arrays.equals(contexts, other.contexts) &&
Arrays.equals(boosts, other.boosts) &&
rewriteMethod.equals(other.rewriteMethod);
}
@Override
public int hashCode() {
int h = classHash();
h = 31 * h + Arrays.hashCode(terms);
h = 31 * h + Arrays.hashCode(contexts);
h = 31 * h + Arrays.hashCode(boosts);
h = 31 * h + rewriteMethod.hashCode();
return h;
}
@Override
public String toString(String field) {
StringBuilder builder = new StringBuilder("Blended(");
for (int i = 0; i < terms.length; ++i) {
if (i != 0) {
builder.append(" ");
}
Query termQuery = new TermQuery(terms[i]);
if (boosts[i] != 1f) {
termQuery = new BoostQuery(termQuery, boosts[i]);
}
builder.append(termQuery.toString(field));
}
builder.append(")");
return builder.toString();
}
@Override
public final Query rewrite(IndexReader reader) throws IOException {
final TermStates[] contexts = ArrayUtil.copyOfSubArray(this.contexts, 0, this.contexts.length);
for (int i = 0; i < contexts.length; ++i) {
if (contexts[i] == null || contexts[i].wasBuiltFor(reader.getContext()) == false) {
contexts[i] = TermStates.build(reader.getContext(), terms[i], true);
}
}
// Compute aggregated doc freq and total term freq
// df will be the max of all doc freqs
// ttf will be the sum of all total term freqs
int df = 0;
long ttf = 0;
for (TermStates ctx : contexts) {
df = Math.max(df, ctx.docFreq());
ttf += ctx.totalTermFreq();
}
for (int i = 0; i < contexts.length; ++i) {
contexts[i] = adjustFrequencies(reader.getContext(), contexts[i], df, ttf);
}
Query[] termQueries = new Query[terms.length];
for (int i = 0; i < terms.length; ++i) {
termQueries[i] = new TermQuery(terms[i], contexts[i]);
if (boosts[i] != 1f) {
termQueries[i] = new BoostQuery(termQueries[i], boosts[i]);
}
}
return rewriteMethod.rewrite(termQueries);
}
@Override
public void visit(QueryVisitor visitor) {
Term[] termsToVisit = Arrays.stream(terms).filter(t -> visitor.acceptField(t.field())).toArray(Term[]::new);
if (termsToVisit.length > 0) {
QueryVisitor v = visitor.getSubVisitor(Occur.SHOULD, this);
v.consumeTerms(this, termsToVisit);
}
}
private static TermStates adjustFrequencies(IndexReaderContext readerContext,
TermStates ctx, int artificialDf, long artificialTtf) throws IOException {
List<LeafReaderContext> leaves = readerContext.leaves();
final int len;
if (leaves == null) {
len = 1;
} else {
len = leaves.size();
}
TermStates newCtx = new TermStates(readerContext);
for (int i = 0; i < len; ++i) {
TermState termState = ctx.get(leaves.get(i));
if (termState == null) {
continue;
}
newCtx.register(termState, i);
}
newCtx.accumulateStatistics(artificialDf, artificialTtf);
return newCtx;
}
}