/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.PrefixCodedTerms;
import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.ByteRunAutomaton;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.apache.lucene.util.automaton.Operations;
/**
* Specialization for a disjunction over many terms that behaves like a
* {@link ConstantScoreQuery} over a {@link BooleanQuery} containing only
* {@link org.apache.lucene.search.BooleanClause.Occur#SHOULD} clauses.
* <p>For instance, in the following example, both {@code q1} and {@code q2}
* would yield the same scores:
* <pre class="prettyprint">
* Query q1 = new TermInSetQuery("field", new BytesRef("foo"), new BytesRef("bar"));
*
* BooleanQuery.Builder bq = new BooleanQuery.Builder();
* bq.add(new TermQuery(new Term("field", "foo")), Occur.SHOULD);
* bq.add(new TermQuery(new Term("field", "bar")), Occur.SHOULD);
* Query q2 = new ConstantScoreQuery(bq.build());
* </pre>
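* <p>A common use case is filtering on a large set of ids; as a sketch, assuming
* a hypothetical {@code id} field:
* <pre class="prettyprint">
* Collection&lt;BytesRef&gt; ids = ...; // e.g. thousands of primary keys
* Query filter = new TermInSetQuery("id", ids);
* </pre>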
* <p>When there are few terms, this query executes like a regular disjunction.
* However, when there are many terms, instead of merging iterators on the fly,
* it will populate a bit set with matching docs and return a {@link Scorer}
* over this bit set.
* <p>NOTE: This query produces scores that are equal to its boost.
*/
public class TermInSetQuery extends Query implements Accountable {
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(TermInSetQuery.class);
// Same threshold as MultiTermQueryConstantScoreWrapper
static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
private final String field;
private final PrefixCodedTerms termData;
private final int termDataHashCode; // cached hashcode of termData
/**
* Creates a new {@link TermInSetQuery} from the given collection of terms.
*/
public TermInSetQuery(String field, Collection<BytesRef> terms) {
BytesRef[] sortedTerms = terms.toArray(new BytesRef[0]);
// already sorted if we are a SortedSet with natural order
boolean sorted = terms instanceof SortedSet && ((SortedSet<BytesRef>)terms).comparator() == null;
if (!sorted) {
ArrayUtil.timSort(sortedTerms);
}
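// encode the sorted terms with PrefixCodedTerms, which shares prefixes between
// consecutive terms to keep memory usage low for large sets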
PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
BytesRefBuilder previous = null;
for (BytesRef term : sortedTerms) {
if (previous == null) {
previous = new BytesRefBuilder();
} else if (previous.get().equals(term)) {
continue; // deduplicate
}
builder.add(field, term);
previous.copyBytes(term);
}
this.field = field;
termData = builder.finish();
termDataHashCode = termData.hashCode();
}
/**
* Creates a new {@link TermInSetQuery} from the given array of terms.
*/
public TermInSetQuery(String field, BytesRef... terms) {
this(field, Arrays.asList(terms));
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
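// with few terms, rewrite to an equivalent ConstantScoreQuery over a boolean disjunction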
final int threshold = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount());
if (termData.size() <= threshold) {
BooleanQuery.Builder bq = new BooleanQuery.Builder();
TermIterator iterator = termData.iterator();
for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
bq.add(new TermQuery(new Term(iterator.field(), BytesRef.deepCopyOf(term))), Occur.SHOULD);
}
return new ConstantScoreQuery(bq.build());
}
return super.rewrite(reader);
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field) == false) {
return;
}
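// a single term is reported directly; larger sets are exposed as an automaton
// so that visitors can intersect them with the index lazily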
if (termData.size() == 1) {
visitor.consumeTerms(this, new Term(field, termData.iterator().next()));
}
if (termData.size() > 1) {
visitor.consumeTermsMatching(this, field, this::asByteRunAutomaton);
}
}
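/** Builds a {@link ByteRunAutomaton} matching exactly the union of this query's terms. */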
private ByteRunAutomaton asByteRunAutomaton() {
TermIterator iterator = termData.iterator();
List<Automaton> automata = new ArrayList<>();
for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
automata.add(Automata.makeBinary(term));
}
return new CompiledAutomaton(Operations.union(automata)).runAutomaton;
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) &&
equalsTo(getClass().cast(other));
}
private boolean equalsTo(TermInSetQuery other) {
// no need to check 'field' explicitly since it is encoded in 'termData'
// termData might be heavy to compare so check the hash code first
return termDataHashCode == other.termDataHashCode &&
termData.equals(other.termData);
}
@Override
public int hashCode() {
return 31 * classHash() + termDataHashCode;
}
/** Returns the terms wrapped in a {@link PrefixCodedTerms}. */
public PrefixCodedTerms getTermData() {
return termData;
}
@Override
public String toString(String defaultField) {
StringBuilder builder = new StringBuilder();
builder.append(field);
builder.append(":(");
TermIterator iterator = termData.iterator();
boolean first = true;
for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
if (!first) {
builder.append(' ');
}
first = false;
builder.append(Term.toString(term));
}
builder.append(')');
return builder.toString();
}
@Override
public long ramBytesUsed() {
return BASE_RAM_BYTES_USED + termData.ramBytesUsed();
}
@Override
public Collection<Accountable> getChildResources() {
return Collections.emptyList();
}
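/**
 * Caches a term along with its {@link TermState} and statistics so that the
 * terms enum can later be repositioned without paying for a fresh seek.
 */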
private static class TermAndState {
final String field;
final TermsEnum termsEnum;
final BytesRef term;
final TermState state;
final int docFreq;
final long totalTermFreq;
TermAndState(String field, TermsEnum termsEnum) throws IOException {
this.field = field;
this.termsEnum = termsEnum;
this.term = BytesRef.deepCopyOf(termsEnum.term());
this.state = termsEnum.termState();
this.docFreq = termsEnum.docFreq();
this.totalTermFreq = termsEnum.totalTermFreq();
}
}
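/** Holds either a {@link Weight}, when few terms matched, or a {@link DocIdSet}, when many did. */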
private static class WeightOrDocIdSet {
final Weight weight;
final DocIdSet set;
WeightOrDocIdSet(Weight weight) {
this.weight = Objects.requireNonNull(weight);
this.set = null;
}
WeightOrDocIdSet(DocIdSet bitset) {
this.set = bitset;
this.weight = null;
}
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public void extractTerms(Set<Term> terms) {
// no-op
// This query is for abuse cases where the number of terms is too large to
// run efficiently as a BooleanQuery, so we likewise hide its terms to
// protect highlighters.
}
@Override
public Matches matches(LeafReaderContext context, int doc) throws IOException {
Terms terms = context.reader().terms(field);
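// without positions there is nothing to iterate over; defer to the default
// implementation, which only reports whether the document matched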
if (terms == null || terms.hasPositions() == false) {
return super.matches(context, doc);
}
return MatchesUtils.forField(field, () -> DisjunctionMatchesIterator.fromTermsEnum(context, doc, getQuery(), field, termData.iterator()));
}
/**
* On the given leaf context, try to either rewrite to a disjunction if
* there are few matching terms, or build a bitset containing matching docs.
*/
private WeightOrDocIdSet rewrite(LeafReaderContext context) throws IOException {
final LeafReader reader = context.reader();
Terms terms = reader.terms(field);
if (terms == null) {
return null;
}
TermsEnum termsEnum = terms.iterator();
PostingsEnum docs = null;
TermIterator iterator = termData.iterator();
// We will first try to collect up to 'threshold' terms into 'matchingTerms';
// if there are too many terms, we will fall back to building the 'builder'.
final int threshold = Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, BooleanQuery.getMaxClauseCount());
assert termData.size() > threshold : "Query should have been rewritten";
List<TermAndState> matchingTerms = new ArrayList<>(threshold);
DocIdSetBuilder builder = null;
for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
assert field.equals(iterator.field());
if (termsEnum.seekExact(term)) {
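// matchingTerms == null means the threshold was exceeded earlier, so feed
// postings for this term straight into the bit set builder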
if (matchingTerms == null) {
docs = termsEnum.postings(docs, PostingsEnum.NONE);
builder.add(docs);
} else if (matchingTerms.size() < threshold) {
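// still under the threshold: buffer the term, since a boolean query may be cheaper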
matchingTerms.add(new TermAndState(field, termsEnum));
} else {
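// the threshold has just been exceeded: switch to building a bit set and
// replay the postings of the terms buffered so far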
assert matchingTerms.size() == threshold;
builder = new DocIdSetBuilder(reader.maxDoc(), terms);
docs = termsEnum.postings(docs, PostingsEnum.NONE);
builder.add(docs);
for (TermAndState t : matchingTerms) {
t.termsEnum.seekExact(t.term, t.state);
docs = t.termsEnum.postings(docs, PostingsEnum.NONE);
builder.add(docs);
}
matchingTerms = null;
}
}
}
if (matchingTerms != null) {
assert builder == null;
BooleanQuery.Builder bq = new BooleanQuery.Builder();
for (TermAndState t : matchingTerms) {
final TermStates termStates = new TermStates(searcher.getTopReaderContext());
termStates.register(t.state, context.ord, t.docFreq, t.totalTermFreq);
bq.add(new TermQuery(new Term(t.field, t.term), termStates), Occur.SHOULD);
}
Query q = new ConstantScoreQuery(bq.build());
final Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score());
return new WeightOrDocIdSet(weight);
} else {
assert builder != null;
return new WeightOrDocIdSet(builder.build());
}
}
private Scorer scorer(DocIdSet set) throws IOException {
if (set == null) {
return null;
}
final DocIdSetIterator disi = set.iterator();
if (disi == null) {
return null;
}
return new ConstantScoreScorer(this, score(), scoreMode, disi);
}
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
final WeightOrDocIdSet weightOrBitSet = rewrite(context);
if (weightOrBitSet == null) {
return null;
} else if (weightOrBitSet.weight != null) {
return weightOrBitSet.weight.bulkScorer(context);
} else {
final Scorer scorer = scorer(weightOrBitSet.set);
if (scorer == null) {
return null;
}
return new DefaultBulkScorer(scorer);
}
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
final WeightOrDocIdSet weightOrBitSet = rewrite(context);
if (weightOrBitSet == null) {
return null;
} else if (weightOrBitSet.weight != null) {
return weightOrBitSet.weight.scorer(context);
} else {
return scorer(weightOrBitSet.set);
}
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
// Only cache instances that have a reasonable size. Otherwise it might cause memory issues
// with the query cache if most memory ends up being spent on queries rather than doc id sets.
return ramBytesUsed() <= RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED;
}
};
}
}