blob: 5c361cf7e9e6135ba80e7569aeb75c7fe7c9ab9c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;
/**
 * A query that executes high-frequency terms in an optional sub-query to prevent
 * slow queries due to "common" terms like stopwords. This query
 * builds 2 queries off the {@link #add(Term) added} terms: low-frequency
 * terms are added to a required boolean clause and high-frequency terms are
 * added to an optional boolean clause. The optional clause is only executed if
 * the required "low-frequency" clause matches. In most cases, high-frequency terms are
 * unlikely to significantly contribute to the document score unless at least
 * one of the low-frequency terms is matched. This query can improve
 * query execution times significantly if applicable.
 * <p>
 * {@link CommonTermsQuery} has several advantages over stopword filtering at
 * index or query time since a term can be "classified" based on the actual
 * document frequency in the index and can prevent slow queries even across
 * domains without specialized stopword files.
 * </p>
 * <p>
 * <b>Note:</b> if the query only contains high-frequency terms (and no
 * high-frequency minimum-should-match has been configured) the query is
 * rewritten into a plain conjunction query, i.e. all high-frequency terms need to
 * match in order to match a document.
 * </p>
 */
public class CommonTermsQuery extends Query {
  /*
   * TODO maybe it would make sense to abstract this even further and allow to
   * rewrite to dismax rather than boolean. Yet, this can already be subclassed
   * to do so.
   */

  /** Terms added via {@link #add(Term)}, in insertion order. */
  protected final List<Term> terms = new ArrayList<>();
  /**
   * Document-frequency threshold: a fraction of maxDoc in [0..1) or an absolute
   * count &gt;= 1. Terms whose document frequency exceeds it are high-frequency.
   */
  protected final float maxTermFrequency;
  /** {@link Occur} used for each low-frequency term clause. */
  protected final Occur lowFreqOccur;
  /** {@link Occur} used for each high-frequency term clause. */
  protected final Occur highFreqOccur;
  /** Boost applied to the low-frequency sub-query. */
  protected float lowFreqBoost = 1.0f;
  /** Boost applied to the high-frequency sub-query. */
  protected float highFreqBoost = 1.0f;
  /** Low-frequency minimum-should-match: fraction in [0..1) or absolute count &gt;= 1. */
  protected float lowFreqMinNrShouldMatch = 0;
  /** High-frequency minimum-should-match: fraction in [0..1) or absolute count &gt;= 1. */
  protected float highFreqMinNrShouldMatch = 0;

  /**
   * Creates a new {@link CommonTermsQuery}
   *
   * @param highFreqOccur
   *          {@link Occur} used for high frequency terms
   * @param lowFreqOccur
   *          {@link Occur} used for low frequency terms
   * @param maxTermFrequency
   *          a value in [0..1) (or absolute number &gt;=1) representing the
   *          maximum threshold of a terms document frequency to be considered a
   *          low frequency term.
   * @throws IllegalArgumentException
   *           if {@link Occur#MUST_NOT} is passed as lowFreqOccur or
   *           highFreqOccur
   */
  public CommonTermsQuery(Occur highFreqOccur, Occur lowFreqOccur,
      float maxTermFrequency) {
    if (highFreqOccur == Occur.MUST_NOT) {
      throw new IllegalArgumentException(
          "highFreqOccur should be MUST or SHOULD but was MUST_NOT");
    }
    if (lowFreqOccur == Occur.MUST_NOT) {
      throw new IllegalArgumentException(
          "lowFreqOccur should be MUST or SHOULD but was MUST_NOT");
    }
    this.highFreqOccur = highFreqOccur;
    this.lowFreqOccur = lowFreqOccur;
    this.maxTermFrequency = maxTermFrequency;
  }

  /**
   * Adds a term to the {@link CommonTermsQuery}
   *
   * @param term
   *          the term to add
   * @throws IllegalArgumentException if {@code term} is {@code null}
   */
  public void add(Term term) {
    if (term == null) {
      throw new IllegalArgumentException("Term must not be null");
    }
    this.terms.add(term);
  }

  /**
   * Rewrites this query into a boolean query of low- and high-frequency
   * sub-queries, classifying each term by its document frequency in
   * {@code reader}.
   * <p>
   * With no terms this rewrites to a {@link MatchNoDocsQuery}; with exactly one
   * term it rewrites to a plain term query (boosts and occurs are not applied
   * in that case).
   */
  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    if (this.terms.isEmpty()) {
      return new MatchNoDocsQuery("CommonTermsQuery with no terms");
    } else if (this.terms.size() == 1) {
      return newTermQuery(this.terms.get(0), null);
    }
    final List<LeafReaderContext> leaves = reader.leaves();
    final int maxDoc = reader.maxDoc();
    final TermStates[] contextArray = new TermStates[terms.size()];
    final Term[] queryTerms = this.terms.toArray(new Term[0]);
    collectTermStates(reader, leaves, contextArray, queryTerms);
    return buildQuery(maxDoc, contextArray, queryTerms);
  }

  /**
   * Reports the terms whose field is accepted by {@code visitor} to a
   * {@link Occur#SHOULD} sub-visitor; nothing is reported if no field matches.
   */
  @Override
  public void visit(QueryVisitor visitor) {
    Term[] selectedTerms = terms.stream().filter(t -> visitor.acceptField(t.field())).toArray(Term[]::new);
    if (selectedTerms.length > 0) {
      QueryVisitor v = visitor.getSubVisitor(Occur.SHOULD, this);
      v.consumeTerms(this, selectedTerms);
    }
  }

  /**
   * Computes the minimum number of low-frequency clauses that must match.
   *
   * @param numOptional number of low-frequency clauses
   * @return the resolved minimum-should-match count
   */
  protected int calcLowFreqMinimumNumberShouldMatch(int numOptional) {
    return minNrShouldMatch(lowFreqMinNrShouldMatch, numOptional);
  }

  /**
   * Computes the minimum number of high-frequency clauses that must match.
   *
   * @param numOptional number of high-frequency clauses
   * @return the resolved minimum-should-match count
   */
  protected int calcHighFreqMinimumNumberShouldMatch(int numOptional) {
    return minNrShouldMatch(highFreqMinNrShouldMatch, numOptional);
  }

  /**
   * Resolves a minimum-should-match setting: values &gt;= 1 (or exactly 0) are
   * truncated to an absolute count; other values are treated as a fraction of
   * {@code numOptional} and rounded.
   */
  private int minNrShouldMatch(float minNrShouldMatch, int numOptional) {
    if (minNrShouldMatch >= 1.0f || minNrShouldMatch == 0.0f) {
      return (int) minNrShouldMatch;
    }
    return Math.round(minNrShouldMatch * numOptional);
  }

  /**
   * Builds the rewritten query. A term is high-frequency when its document
   * frequency exceeds {@link #maxTermFrequency} (absolute threshold) or
   * {@code ceil(maxTermFrequency * maxDoc)} (relative threshold); all other
   * terms, including terms not found in the index ({@code null} states), are
   * low-frequency. The low-frequency sub-query is added as a
   * {@link Occur#MUST} clause and the high-frequency sub-query as a
   * {@link Occur#SHOULD} clause, each wrapped in a {@link BoostQuery}.
   *
   * @param maxDoc the reader's maxDoc, used for the relative threshold
   * @param contextArray per-term states, parallel to {@code queryTerms}; entries may be {@code null}
   * @param queryTerms the terms to classify
   * @return the combined boolean query
   */
  protected Query buildQuery(final int maxDoc,
      final TermStates[] contextArray, final Term[] queryTerms) {
    List<Query> lowFreqQueries = new ArrayList<>();
    List<Query> highFreqQueries = new ArrayList<>();
    for (int i = 0; i < queryTerms.length; i++) {
      TermStates termStates = contextArray[i];
      if (termStates == null) {
        // term does not occur in any segment: treat it as low-frequency
        lowFreqQueries.add(newTermQuery(queryTerms[i], null));
      } else {
        if ((maxTermFrequency >= 1f && termStates.docFreq() > maxTermFrequency)
            || (termStates.docFreq() > (int) Math.ceil(maxTermFrequency
                * (float) maxDoc))) {
          highFreqQueries
              .add(newTermQuery(queryTerms[i], termStates));
        } else {
          lowFreqQueries.add(newTermQuery(queryTerms[i], termStates));
        }
      }
    }
    final int numLowFreqClauses = lowFreqQueries.size();
    final int numHighFreqClauses = highFreqQueries.size();
    Occur lowFreqOccur = this.lowFreqOccur;
    Occur highFreqOccur = this.highFreqOccur;
    int lowFreqMinShouldMatch = 0;
    int highFreqMinShouldMatch = 0;
    if (lowFreqOccur == Occur.SHOULD && numLowFreqClauses > 0) {
      lowFreqMinShouldMatch = calcLowFreqMinimumNumberShouldMatch(numLowFreqClauses);
    }
    if (highFreqOccur == Occur.SHOULD && numHighFreqClauses > 0) {
      highFreqMinShouldMatch = calcHighFreqMinimumNumberShouldMatch(numHighFreqClauses);
    }
    if (lowFreqQueries.isEmpty() && highFreqMinShouldMatch == 0) {
      /*
       * if lowFreq is empty we rewrite the high freq terms in a conjunction to
       * prevent slow queries.
       */
      highFreqOccur = Occur.MUST;
    }
    BooleanQuery.Builder builder = new BooleanQuery.Builder();

    if (lowFreqQueries.isEmpty() == false) {
      BooleanQuery.Builder lowFreq = new BooleanQuery.Builder();
      for (Query query : lowFreqQueries) {
        lowFreq.add(query, lowFreqOccur);
      }
      lowFreq.setMinimumNumberShouldMatch(lowFreqMinShouldMatch);
      Query lowFreqQuery = lowFreq.build();
      builder.add(new BoostQuery(lowFreqQuery, lowFreqBoost), Occur.MUST);
    }
    if (highFreqQueries.isEmpty() == false) {
      BooleanQuery.Builder highFreq = new BooleanQuery.Builder();
      for (Query query : highFreqQueries) {
        highFreq.add(query, highFreqOccur);
      }
      highFreq.setMinimumNumberShouldMatch(highFreqMinShouldMatch);
      Query highFreqQuery = highFreq.build();
      builder.add(new BoostQuery(highFreqQuery, highFreqBoost), Occur.SHOULD);
    }
    return builder.build();
  }

  /**
   * Looks up every query term in every leaf and accumulates the per-leaf term
   * states into {@code contextArray} (same index as the term in
   * {@code queryTerms}). Entries remain {@code null} for terms that are not
   * found in any leaf.
   *
   * @param reader top-level reader whose context backs newly created {@link TermStates}
   * @param leaves the leaves to search
   * @param contextArray output array, parallel to {@code queryTerms}
   * @param queryTerms the terms to look up
   * @throws IOException if reading the index fails
   */
  public void collectTermStates(IndexReader reader,
      List<LeafReaderContext> leaves, TermStates[] contextArray,
      Term[] queryTerms) throws IOException {
    for (LeafReaderContext context : leaves) {
      for (int i = 0; i < queryTerms.length; i++) {
        Term term = queryTerms[i];
        TermStates termStates = contextArray[i];
        // named fieldTerms to avoid shadowing the "terms" field of this query
        final Terms fieldTerms = context.reader().terms(term.field());
        if (fieldTerms == null) {
          // field does not exist
          continue;
        }
        final TermsEnum termsEnum = fieldTerms.iterator();
        assert termsEnum != null;

        if (termsEnum == TermsEnum.EMPTY) continue;
        if (termsEnum.seekExact(term.bytes())) {
          if (termStates == null) {
            contextArray[i] = new TermStates(reader.getContext(),
                termsEnum.termState(), context.ord, termsEnum.docFreq(),
                termsEnum.totalTermFreq());
          } else {
            termStates.register(termsEnum.termState(), context.ord,
                termsEnum.docFreq(), termsEnum.totalTermFreq());
          }
        }
      }
    }
  }

  /**
   * Specifies a minimum number of the low frequent optional BooleanClauses which must be
   * satisfied in order to produce a match on the low frequency terms query
   * part. This method accepts a float value in the range [0..1) as a fraction
   * of the actual query terms in the low frequent clause or a number
   * {@code >= 1} as an absolute number of clauses that need to match.
   *
   * <p>
   * By default no optional clauses are necessary for a match (unless there are
   * no required clauses). If this method is used, then the specified number of
   * clauses is required.
   * </p>
   *
   * @param min
   *          the number of optional clauses that must match
   */
  public void setLowFreqMinimumNumberShouldMatch(float min) {
    this.lowFreqMinNrShouldMatch = min;
  }

  /**
   * Gets the minimum number of the optional low frequent BooleanClauses which must be
   * satisfied.
   */
  public float getLowFreqMinimumNumberShouldMatch() {
    return lowFreqMinNrShouldMatch;
  }

  /**
   * Specifies a minimum number of the high frequent optional BooleanClauses which must be
   * satisfied in order to produce a match on the high frequency terms query
   * part. This method accepts a float value in the range [0..1) as a fraction
   * of the actual query terms in the high frequent clause or a number
   * {@code >= 1} as an absolute number of clauses that need to match.
   *
   * <p>
   * By default no optional clauses are necessary for a match (unless there are
   * no required clauses). If this method is used, then the specified number of
   * clauses is required.
   * </p>
   *
   * @param min
   *          the number of optional clauses that must match
   */
  public void setHighFreqMinimumNumberShouldMatch(float min) {
    this.highFreqMinNrShouldMatch = min;
  }

  /**
   * Gets the minimum number of the optional high frequent BooleanClauses which must be
   * satisfied.
   */
  public float getHighFreqMinimumNumberShouldMatch() {
    return highFreqMinNrShouldMatch;
  }

  /**
   * Gets the list of terms (an unmodifiable view).
   */
  public List<Term> getTerms() {
    return Collections.unmodifiableList(terms);
  }

  /**
   * Gets the maximum threshold of a terms document frequency to be considered a
   * low frequency term.
   */
  public float getMaxTermFrequency() {
    return maxTermFrequency;
  }

  /**
   * Gets the {@link Occur} used for low frequency terms.
   */
  public Occur getLowFreqOccur() {
    return lowFreqOccur;
  }

  /**
   * Gets the {@link Occur} used for high frequency terms.
   */
  public Occur getHighFreqOccur() {
    return highFreqOccur;
  }

  /**
   * Gets the boost used for low frequency terms.
   */
  public float getLowFreqBoost() {
    return lowFreqBoost;
  }

  /**
   * Gets the boost used for high frequency terms.
   */
  public float getHighFreqBoost() {
    return highFreqBoost;
  }

  /**
   * Renders the terms as a comma-separated list, followed by
   * {@code ~(lowMin, highMin)} when either minimum-should-match is positive.
   */
  @Override
  public String toString(String field) {
    StringBuilder buffer = new StringBuilder();
    boolean needParens = (getLowFreqMinimumNumberShouldMatch() > 0);
    if (needParens) {
      buffer.append("(");
    }
    for (int i = 0; i < terms.size(); i++) {
      Term t = terms.get(i);
      buffer.append(newTermQuery(t, null).toString());

      if (i != terms.size() - 1) buffer.append(", ");
    }
    if (needParens) {
      buffer.append(")");
    }
    if (getLowFreqMinimumNumberShouldMatch() > 0 || getHighFreqMinimumNumberShouldMatch() > 0) {
      // separate the two values; concatenating them (e.g. "0.51.0") is ambiguous
      buffer.append('~');
      buffer.append("(");
      buffer.append(getLowFreqMinimumNumberShouldMatch());
      buffer.append(", ");
      buffer.append(getHighFreqMinimumNumberShouldMatch());
      buffer.append(")");
    }
    return buffer.toString();
  }

  @Override
  public int hashCode() {
    final int prime = 31;
    int result = classHash();
    result = prime * result + Float.floatToIntBits(highFreqBoost);
    result = prime * result + Objects.hashCode(highFreqOccur);
    result = prime * result + Objects.hashCode(lowFreqOccur);
    result = prime * result + Float.floatToIntBits(lowFreqBoost);
    result = prime * result + Float.floatToIntBits(maxTermFrequency);
    result = prime * result + Float.floatToIntBits(lowFreqMinNrShouldMatch);
    result = prime * result + Float.floatToIntBits(highFreqMinNrShouldMatch);
    result = prime * result + Objects.hashCode(terms);
    return result;
  }

  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) &&
           equalsTo(getClass().cast(other));
  }

  /**
   * Field-by-field comparison. All float fields are compared by bit pattern so
   * that equality stays consistent with {@link #hashCode()} (plain {@code ==}
   * would treat 0.0f and -0.0f as equal while hashing them differently, and
   * would make NaN unequal to itself).
   */
  private boolean equalsTo(CommonTermsQuery other) {
    return Float.floatToIntBits(highFreqBoost) == Float.floatToIntBits(other.highFreqBoost) &&
           highFreqOccur == other.highFreqOccur &&
           lowFreqOccur == other.lowFreqOccur &&
           Float.floatToIntBits(lowFreqBoost) == Float.floatToIntBits(other.lowFreqBoost) &&
           Float.floatToIntBits(maxTermFrequency) == Float.floatToIntBits(other.maxTermFrequency) &&
           Float.floatToIntBits(lowFreqMinNrShouldMatch) == Float.floatToIntBits(other.lowFreqMinNrShouldMatch) &&
           Float.floatToIntBits(highFreqMinNrShouldMatch) == Float.floatToIntBits(other.highFreqMinNrShouldMatch) &&
           terms.equals(other.terms);
  }

  /**
   * Builds a new TermQuery instance.
   * <p>This is intended for subclasses that wish to customize the generated queries.</p>
   * @param term term
   * @param termStates the TermStates to be used to create the low level term query. Can be <code>null</code>.
   * @return new TermQuery instance
   */
  protected Query newTermQuery(Term term, TermStates termStates) {
    return termStates == null ? new TermQuery(term) : new TermQuery(term, termStates);
  }
}