// blob: 763c8bc25b166099d4e8f12470f0d4f2b483c668 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.index.Impact;
import org.apache.lucene.index.Impacts;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.ImpactsSource;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
/**
* A query that treats multiple terms as synonyms.
* <p>
* For scoring purposes, this query tries to score the terms as if you
* had indexed them as one term: it will match any of the terms but
* only invoke the similarity a single time, scoring the sum of all
* term frequencies for the document.
*/
public final class SynonymQuery extends Query {
private final TermAndBoost terms[];
private final String field;
/**
 * A builder for {@link SynonymQuery}.
 * <p>
 * Synonyms are accumulated through {@link #addTerm(Term)} and
 * {@link #addTerm(Term, float)}; {@link #build()} sorts them and creates the query.
 */
public static class Builder {
  private final String field;
  private final List<TermAndBoost> terms = new ArrayList<>();

  /**
   * Sole constructor
   *
   * @param field The target field name
   */
  public Builder(String field) {
    this.field = field;
  }

  /**
   * Adds the provided {@code term} as a synonym with a boost of {@code 1}.
   *
   * @param term the synonym to add; its field must equal this builder's field
   * @return this builder
   */
  public Builder addTerm(Term term) {
    return addTerm(term, 1f);
  }

  /**
   * Adds the provided {@code term} as a synonym, term frequencies of this term
   * will be multiplied by {@code boost} when the query is scored.
   *
   * @param term the synonym to add; its field must equal this builder's field
   * @param boost the weight of this synonym, between 0 (exclusive) and 1 (inclusive)
   * @return this builder
   * @throws IllegalArgumentException if the term is on another field, or if
   *         {@code boost} is NaN or outside of {@code (0, 1]}
   * @throws BooleanQuery.TooManyClauses if the builder already holds
   *         {@link BooleanQuery#getMaxClauseCount()} terms
   */
  public Builder addTerm(Term term, float boost) {
    if (field.equals(term.field()) == false) {
      throw new IllegalArgumentException("Synonyms must be across the same field");
    }
    if (Float.isNaN(boost) || Float.compare(boost, 0f) <= 0 || Float.compare(boost, 1f) > 0) {
      throw new IllegalArgumentException("boost must be a positive float between 0 (exclusive) and 1 (inclusive)");
    }
    // Check the limit before mutating the list so that a rejected call leaves
    // the builder exactly as it was.
    if (terms.size() >= BooleanQuery.getMaxClauseCount()) {
      throw new BooleanQuery.TooManyClauses();
    }
    terms.add(new TermAndBoost(term, boost));
    return this;
  }

  /**
   * Builds the {@link SynonymQuery}.
   * <p>
   * The accumulated terms are sorted in place before being copied into the query.
   */
  public SynonymQuery build() {
    Collections.sort(terms);
    return new SynonymQuery(terms.toArray(new TermAndBoost[0]), field);
  }
}
/**
 * Creates a new SynonymQuery, matching any of the supplied terms, each with a boost of 1.
 * <p>
 * The terms must all have the same field. The field of the query is taken from the
 * first term and stays {@code null} when no term is supplied.
 *
 * @deprecated Please use a {@link Builder} instead.
 */
@Deprecated
public SynonymQuery(Term... terms) {
  Objects.requireNonNull(terms);
  final int count = terms.length;
  if (count > BooleanQuery.getMaxClauseCount()) {
    throw new BooleanQuery.TooManyClauses();
  }
  final TermAndBoost[] wrapped = new TermAndBoost[count];
  // field shared by every term seen so far; null until the first term is read
  String commonField = null;
  int slot = 0;
  for (Term candidate : terms) {
    wrapped[slot++] = new TermAndBoost(candidate, 1.0f);
    final String candidateField = candidate.field();
    if (commonField == null) {
      commonField = candidateField;
    } else if (candidateField.equals(commonField) == false) {
      throw new IllegalArgumentException("Synonyms must be across the same field");
    }
  }
  Arrays.sort(wrapped);
  this.terms = wrapped;
  this.field = commonField;
}
/**
 * Internal constructor: adopts the given array and field as they are.
 * <p>
 * No copy, sorting or field validation happens here; the only check is that
 * {@code terms} is not {@code null}.
 */
private SynonymQuery(TermAndBoost[] terms, String field) {
  Objects.requireNonNull(terms);
  this.field = field;
  this.terms = terms;
}
/**
 * Returns the synonym terms of this query, without their boosts, as an
 * unmodifiable list in the query's internal order.
 */
public List<Term> getTerms() {
  final List<Term> plainTerms = new ArrayList<>(terms.length);
  for (TermAndBoost entry : terms) {
    plainTerms.add(entry.getTerm());
  }
  return Collections.unmodifiableList(plainTerms);
}
/**
 * Renders the query as {@code Synonym(t1 t2^boost ...)}: every term is printed the
 * way a {@link TermQuery} prints it, and the boost is appended only when it is not 1.
 */
@Override
public String toString(String field) {
  final StringBuilder out = new StringBuilder("Synonym(");
  // empty before the first term, a single space afterwards
  String separator = "";
  for (TermAndBoost entry : terms) {
    out.append(separator);
    separator = " ";
    final Query asTermQuery = new TermQuery(entry.term);
    out.append(asTermQuery.toString(field));
    if (entry.boost != 1f) {
      out.append("^").append(entry.boost);
    }
  }
  return out.append(")").toString();
}
/**
 * Combines the class hash with the hash of the (term, boost) array.
 */
@Override
public int hashCode() {
  final int termsHash = Arrays.hashCode(terms);
  return termsHash + 31 * classHash();
}
/**
 * Two synonym queries are equal when they are of the same class and hold equal
 * (term, boost) entries in the same order.
 */
@Override
public boolean equals(Object other) {
  if (sameClassAs(other) == false) {
    return false;
  }
  final SynonymQuery that = (SynonymQuery) other;
  return Arrays.equals(terms, that.terms);
}
/**
 * Simplifies the degenerate cases: no term becomes an empty {@link BooleanQuery},
 * a single term becomes a {@link TermQuery} (wrapped in a {@link BoostQuery} when
 * its boost is not 1). Queries with two or more terms are returned unchanged.
 */
@Override
public Query rewrite(IndexReader reader) throws IOException {
  switch (terms.length) {
    case 0:
      return new BooleanQuery.Builder().build();
    case 1: {
      final TermAndBoost only = terms[0];
      if (only.boost == 1f) {
        return new TermQuery(only.term);
      }
      return new BoostQuery(new TermQuery(only.term), only.boost);
    }
    default:
      return this;
  }
}
/**
 * Hands all synonym terms to a SHOULD sub-visitor, provided the visitor accepts
 * this query's field.
 */
@Override
public void visit(QueryVisitor visitor) {
  if (visitor.acceptField(field)) {
    final QueryVisitor subVisitor = visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this);
    final Term[] plainTerms = new Term[terms.length];
    for (int i = 0; i < plainTerms.length; i++) {
      plainTerms[i] = terms[i].term;
    }
    subVisitor.consumeTerms(this, plainTerms);
  }
}
/**
 * Creates a {@link SynonymWeight} when scores are needed; otherwise delegates to
 * the weight of an equivalent disjunction of term queries.
 */
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  if (scoreMode.needsScores() == false) {
    // if scores are not needed, let BooleanWeight deal with optimizing that case.
    final BooleanQuery.Builder disjunction = new BooleanQuery.Builder();
    for (int i = 0; i < terms.length; i++) {
      disjunction.add(new TermQuery(terms[i].term), BooleanClause.Occur.SHOULD);
    }
    final Query rewritten = searcher.rewrite(disjunction.build());
    return rewritten.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, boost);
  }
  return new SynonymWeight(this, searcher, scoreMode, boost);
}
class SynonymWeight extends Weight {
private final TermStates termStates[];
private final Similarity similarity;
private final Similarity.SimScorer simWeight;
private final ScoreMode scoreMode;
/**
 * Collects the per-term states and derives the statistics of a single pseudo-term:
 * its document frequency is the largest one among the synonyms and its total term
 * frequency is the sum over the synonyms. Terms with a document frequency of 0
 * do not contribute.
 */
SynonymWeight(Query query, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  super(query);
  assert scoreMode.needsScores();
  this.scoreMode = scoreMode;
  final CollectionStatistics fieldStats = searcher.collectionStatistics(terms[0].term.field());
  final int numTerms = terms.length;
  this.termStates = new TermStates[numTerms];
  long maxDocFreq = 0;
  long sumTotalTermFreq = 0;
  for (int i = 0; i < numTerms; i++) {
    final Term term = terms[i].term;
    final TermStates states = TermStates.build(searcher.getTopReaderContext(), term, true);
    termStates[i] = states;
    if (states.docFreq() <= 0) {
      continue;
    }
    final TermStatistics stats = searcher.termStatistics(term, states.docFreq(), states.totalTermFreq());
    maxDocFreq = Math.max(stats.docFreq(), maxDocFreq);
    sumTotalTermFreq += stats.totalTermFreq();
  }
  this.similarity = searcher.getSimilarity();
  if (maxDocFreq <= 0) {
    this.simWeight = null; // no terms exist at all, we won't use similarity
  } else {
    final TermStatistics pseudoStats =
        new TermStatistics(new BytesRef("synonym pseudo-term"), maxDocFreq, sumTotalTermFreq);
    this.simWeight = similarity.scorer(boost, fieldStats, pseudoStats);
  }
}
/**
 * Adds every synonym term of the enclosing query to the given set.
 */
@Override
public void extractTerms(Set<Term> terms) {
  // the parameter shadows the query's own array, hence the qualified access
  final TermAndBoost[] synonyms = SynonymQuery.this.terms;
  for (int i = 0; i < synonyms.length; i++) {
    terms.add(synonyms[i].term);
  }
}
/**
 * Returns position-aware matches over all synonyms when the field is indexed with
 * positions in this segment, and falls back to the default implementation otherwise.
 */
@Override
public Matches matches(LeafReaderContext context, int doc) throws IOException {
  final String field = terms[0].term.field();
  final Terms indexTerms = context.reader().terms(field);
  final boolean positionsAvailable = indexTerms != null && indexTerms.hasPositions();
  if (positionsAvailable == false) {
    return super.matches(context, doc);
  }
  final List<Term> termList = new ArrayList<>(terms.length);
  for (TermAndBoost entry : terms) {
    termList.add(entry.getTerm());
  }
  return MatchesUtils.forField(field,
      () -> DisjunctionMatchesIterator.fromTerms(context, doc, getQuery(), field, termList));
}
/**
 * Explains the score of {@code doc}: the (possibly boosted) frequency reported by
 * the segment scorer is handed to the similarity, whose explanation is wrapped in
 * a "weight(...)" node. Returns a no-match explanation when the segment has no
 * scorer or the scorer does not land on {@code doc}.
 */
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
  final Scorer scorer = scorer(context);
  if (scorer == null || scorer.iterator().advance(doc) != doc) {
    return Explanation.noMatch("no matching term");
  }
  // the frequency accessor depends on which scorer implementation was built
  final float freq;
  if (scorer instanceof SynonymScorer) {
    freq = ((SynonymScorer) scorer).freq();
  } else if (scorer instanceof FreqBoostTermScorer) {
    freq = ((FreqBoostTermScorer) scorer).freq();
  } else {
    assert scorer instanceof TermScorer;
    freq = ((TermScorer) scorer).freq();
  }
  final LeafSimScorer docScorer =
      new LeafSimScorer(simWeight, context.reader(), terms[0].term.field(), true);
  final Explanation freqExplanation = Explanation.match(freq, "termFreq=" + freq);
  final Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
  final String description = "weight(" + getQuery() + " in " + doc + ") ["
      + similarity.getClass().getSimpleName() + "], result of:";
  return Explanation.match(scoreExplanation.getValue(), description, scoreExplanation);
}
/**
* Builds the scorer for one segment.
* <p>
* Postings are opened for every synonym that has a term state in this segment.
* Returns {@code null} when none has one, a (possibly frequency-boosted)
* {@link TermScorer} when exactly one has, and a {@link SynonymScorer} over the
* disjunction of all of them otherwise.
*/
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
// Three parallel lists: entry i of each describes the same synonym.
List<PostingsEnum> iterators = new ArrayList<>();
List<ImpactsEnum> impacts = new ArrayList<>();
List<Float> termBoosts = new ArrayList<> ();
for (int i = 0; i < terms.length; i++) {
TermState state = termStates[i].get(context);
// a null state means the term has no entry for this segment; skip it
if (state != null) {
TermsEnum termsEnum = context.reader().terms(terms[i].term.field()).iterator();
termsEnum.seekExact(terms[i].term.bytes(), state);
if (scoreMode == ScoreMode.TOP_SCORES) {
// the impacts enum is used both for iteration and as the source of impacts
ImpactsEnum impactsEnum = termsEnum.impacts(PostingsEnum.FREQS);
iterators.add(impactsEnum);
impacts.add(impactsEnum);
} else {
// plain postings; SlowImpactsEnum adapts them to the ImpactsEnum type
PostingsEnum postingsEnum = termsEnum.postings(null, PostingsEnum.FREQS);
iterators.add(postingsEnum);
impacts.add(new SlowImpactsEnum(postingsEnum));
}
termBoosts.add(terms[i].boost);
}
}
if (iterators.isEmpty()) {
return null;
}
LeafSimScorer simScorer = new LeafSimScorer(simWeight, context.reader(), terms[0].term.field(), true);
// we must optimize this case (term not in segment), disjunctions require >= 2 subs
if (iterators.size() == 1) {
final TermScorer scorer;
if (scoreMode == ScoreMode.TOP_SCORES) {
scorer = new TermScorer(this, impacts.get(0), simScorer);
} else {
scorer = new TermScorer(this, iterators.get(0), simScorer);
}
float boost = termBoosts.get(0);
// the plain term scorer is enough unless the frequency has to be scaled by a boost
return scoreMode == ScoreMode.COMPLETE_NO_SCORES || boost == 1f ? scorer : new FreqBoostTermScorer(boost, scorer, simScorer);
}
// we use termscorers + disjunction as an impl detail
DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
for (int i = 0; i < iterators.size(); i++) {
PostingsEnum postings = iterators.get(i);
final TermScorer termScorer = new TermScorer(this, postings, simScorer);
float boost = termBoosts.get(i);
final DisiWrapperFreq wrapper = new DisiWrapperFreq(termScorer, boost);
queue.add(wrapper);
}
// Even though it is called approximation, it is accurate since none of
// the sub iterators are two-phase iterators.
DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue);
// unbox the boosts into the array form expected by mergeImpacts
float[] boosts = new float[impacts.size()];
for (int i = 0; i < boosts.length; i++) {
boosts[i] = termBoosts.get(i);
}
ImpactsSource impactsSource = mergeImpacts(impacts.toArray(new ImpactsEnum[0]), boosts);
ImpactsDISI impactsDisi = new ImpactsDISI(iterator, impactsSource, simScorer.getSimScorer());
// the impacts-aware iterator only drives iteration in TOP_SCORES mode
if (scoreMode == ScoreMode.TOP_SCORES) {
iterator = impactsDisi;
}
return new SynonymScorer(this, queue, iterator, impactsDisi, simScorer);
}
/**
* Always reports this weight as cacheable; {@code ctx} is not consulted.
*/
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
}
/**
* Merge impacts for multiple synonyms.
* <p>
* The returned source exposes the levels and block boundaries of whichever input
* currently has the smallest level-0 boundary, and combines the per-norm frequency
* bounds of all inputs that may still have documents within the requested boundary.
*
* @param impactsEnums one impacts enum per synonym
* @param boosts boost of each synonym, parallel to {@code impactsEnums}; frequencies
*               of an input whose boost is not 1 are scaled by it and rounded up
* @return an {@link ImpactsSource} over the combined impacts
*/
static ImpactsSource mergeImpacts(ImpactsEnum[] impactsEnums, float[] boosts) {
assert impactsEnums.length == boosts.length;
return new ImpactsSource() {
// Cursor over one impact list; remembers the frequency of the impact it just left
// so that the merge loop can add frequency increments.
class SubIterator {
final Iterator<Impact> iterator;
int previousFreq;
Impact current;
SubIterator(Iterator<Impact> iterator) {
this.iterator = iterator;
// reads the first element unconditionally: the list must not be empty
this.current = iterator.next();
}
void next() {
previousFreq = current.freq;
if (iterator.hasNext() == false) {
// null marks an exhausted cursor
current = null;
} else {
current = iterator.next();
}
}
}
@Override
public Impacts getImpacts() throws IOException {
final Impacts[] impacts = new Impacts[impactsEnums.length];
// Use the impacts that have the lower next boundary as a lead.
// It will decide on the number of levels and the block boundaries.
Impacts tmpLead = null;
for (int i = 0; i < impactsEnums.length; ++i) {
impacts[i] = impactsEnums[i].getImpacts();
if (tmpLead == null || impacts[i].getDocIdUpTo(0) < tmpLead.getDocIdUpTo(0)) {
tmpLead = impacts[i];
}
}
final Impacts lead = tmpLead;
return new Impacts() {
@Override
public int numLevels() {
// Delegate to the lead
return lead.numLevels();
}
@Override
public int getDocIdUpTo(int level) {
// Delegate to the lead
return lead.getDocIdUpTo(level);
}
/**
* Return the minimum level whose impacts are valid up to {@code docIdUpTo},
* or {@code -1} if there is no such level.
*/
private int getLevel(Impacts impacts, int docIdUpTo) {
for (int level = 0, numLevels = impacts.numLevels(); level < numLevels; ++level) {
if (impacts.getDocIdUpTo(level) >= docIdUpTo) {
return level;
}
}
return -1;
}
@Override
public List<Impact> getImpacts(int level) {
final int docIdUpTo = getDocIdUpTo(level);
List<List<Impact>> toMerge = new ArrayList<>();
for (int i = 0; i < impactsEnums.length; ++i) {
// only inputs positioned at or before the boundary take part in the merge
if (impactsEnums[i].docID() <= docIdUpTo) {
int impactsLevel = getLevel(impacts[i], docIdUpTo);
if (impactsLevel == -1) {
// One instance doesn't have impacts that cover up to docIdUpTo
// Return impacts that trigger the maximum score
return Collections.singletonList(new Impact(Integer.MAX_VALUE, 1L));
}
final List<Impact> impactList;
if (boosts[i] != 1f) {
float boost = boosts[i];
// scale the frequencies by the boost, rounding up
impactList = impacts[i].getImpacts(impactsLevel)
.stream()
.map(impact -> new Impact((int) Math.ceil(impact.freq * boost), impact.norm))
.collect(Collectors.toList());
} else {
impactList = impacts[i].getImpacts(impactsLevel);
}
toMerge.add(impactList);
}
}
assert toMerge.size() > 0; // otherwise it would mean the docID is > docIdUpTo, which is wrong
if (toMerge.size() == 1) {
// common if one synonym is common and the other one is rare
return toMerge.get(0);
}
// Orders cursors by the unsigned norm of their current impact; exhausted
// cursors are never "less than" a live one.
PriorityQueue<SubIterator> pq = new PriorityQueue<SubIterator>(impacts.length) {
@Override
protected boolean lessThan(SubIterator a, SubIterator b) {
if (a.current == null) { // means iteration is finished
return false;
}
if (b.current == null) {
return true;
}
return Long.compareUnsigned(a.current.norm, b.current.norm) < 0;
}
};
for (List<Impact> impacts : toMerge) {
pq.add(new SubIterator(impacts.iterator()));
}
List<Impact> mergedImpacts = new ArrayList<>();
// Idea: merge impacts by norm. The tricky thing is that we need to
// consider norm values that are not in the impacts too. For
// instance if the list of impacts is [{freq=2,norm=10}, {freq=4,norm=12}],
// there might well be a document that has a freq of 2 and a length of 11,
// which was just not added to the list of impacts because {freq=2,norm=10}
// is more competitive. So the way it works is that we track the sum of
// the term freqs that we have seen so far in order to account for these
// implicit impacts.
long sumTf = 0;
SubIterator top = pq.top();
do {
final long norm = top.current.norm;
// consume every cursor currently positioned on this norm
do {
sumTf += top.current.freq - top.previousFreq;
top.next();
top = pq.updateTop();
} while (top.current != null && top.current.norm == norm);
// sumTf is a long; clamp it to the int range used by Impact
final int freqUpperBound = (int) Math.min(Integer.MAX_VALUE, sumTf);
if (mergedImpacts.isEmpty()) {
mergedImpacts.add(new Impact(freqUpperBound, norm));
} else {
Impact prevImpact = mergedImpacts.get(mergedImpacts.size() - 1);
assert Long.compareUnsigned(prevImpact.norm, norm) < 0;
if (freqUpperBound > prevImpact.freq) {
mergedImpacts.add(new Impact(freqUpperBound, norm));
} // otherwise the previous impact is already more competitive
}
} while (top.current != null);
return mergedImpacts;
}
};
}
@Override
public void advanceShallow(int target) throws IOException {
// only forward to inputs that are still positioned before the target
for (ImpactsEnum impactsEnum : impactsEnums) {
if (impactsEnum.docID() < target) {
impactsEnum.advanceShallow(target);
}
}
}
};
}
/**
 * Scorer over several synonyms: the frequency given to the similarity is the sum
 * of the boosted frequencies of all sub-iterators positioned on the current
 * document. Max-score related calls are forwarded to the {@link ImpactsDISI}.
 */
private static class SynonymScorer extends Scorer {
  private final DisiPriorityQueue queue;
  private final DocIdSetIterator iterator;
  private final ImpactsDISI impactsDisi;
  private final LeafSimScorer simScorer;

  SynonymScorer(Weight weight, DisiPriorityQueue queue, DocIdSetIterator iterator,
      ImpactsDISI impactsDisi, LeafSimScorer simScorer) {
    super(weight);
    this.simScorer = simScorer;
    this.impactsDisi = impactsDisi;
    this.iterator = iterator;
    this.queue = queue;
  }

  @Override
  public int docID() {
    return iterator.docID();
  }

  /** Sums the boosted frequencies along the linked list returned by the queue's top list. */
  float freq() throws IOException {
    DisiWrapperFreq current = (DisiWrapperFreq) queue.topList();
    float sum = current.freq();
    current = (DisiWrapperFreq) current.next;
    while (current != null) {
      sum += current.freq();
      current = (DisiWrapperFreq) current.next;
    }
    return sum;
  }

  @Override
  public float score() throws IOException {
    final int doc = iterator.docID();
    final float combinedFreq = freq();
    return simScorer.score(doc, combinedFreq);
  }

  @Override
  public DocIdSetIterator iterator() {
    return iterator;
  }

  @Override
  public float getMaxScore(int upTo) throws IOException {
    return impactsDisi.getMaxScore(upTo);
  }

  @Override
  public int advanceShallow(int target) throws IOException {
    return impactsDisi.advanceShallow(target);
  }

  @Override
  public void setMinCompetitiveScore(float minScore) {
    impactsDisi.setMinCompetitiveScore(minScore);
  }
}
/**
 * A {@link DisiWrapper} that also keeps the scorer's iterator as a
 * {@link PostingsEnum} together with a boost, so that a boosted term frequency
 * can be read for the current document.
 */
private static class DisiWrapperFreq extends DisiWrapper {
  final PostingsEnum pe;
  final float boost;

  /** The iterator of {@code scorer} is cast to {@link PostingsEnum}. */
  DisiWrapperFreq(Scorer scorer, float boost) {
    super(scorer);
    this.boost = boost;
    this.pe = (PostingsEnum) scorer.iterator();
  }

  /** Returns the postings frequency multiplied by the boost. */
  float freq() throws IOException {
    final int rawFreq = pe.freq();
    return boost * rawFreq;
  }
}
/**
 * Wraps a single {@link TermScorer} so that the term frequency given to the
 * similarity is multiplied by a boost. Iteration and max-score related calls are
 * forwarded to the wrapped scorer.
 */
private static class FreqBoostTermScorer extends FilterScorer {
  final float boost;
  final TermScorer in;
  final LeafSimScorer docScorer;

  /**
   * @param boost frequency multiplier, between 0 (exclusive) and 1 (inclusive)
   * @param in the term scorer to wrap
   * @param docScorer the similarity scorer that receives the boosted frequency
   * @throws IllegalArgumentException if {@code boost} is NaN or outside of {@code (0, 1]}
   */
  public FreqBoostTermScorer(float boost, TermScorer in, LeafSimScorer docScorer) {
    super(in);
    // "<= 0" so that zero is rejected, as the message states and as
    // Builder#addTerm already enforces.
    if (Float.isNaN(boost) || Float.compare(boost, 0f) <= 0 || Float.compare(boost, 1f) > 0) {
      throw new IllegalArgumentException("boost must be a positive float between 0 (exclusive) and 1 (inclusive)");
    }
    this.boost = boost;
    this.in = in;
    this.docScorer = docScorer;
  }

  /** Returns the wrapped scorer's frequency multiplied by the boost. */
  float freq() throws IOException {
    return boost * in.freq();
  }

  @Override
  public float score() throws IOException {
    assert docID() != DocIdSetIterator.NO_MORE_DOCS;
    return docScorer.score(in.docID(), freq());
  }

  @Override
  public float getMaxScore(int upTo) throws IOException {
    return in.getMaxScore(upTo);
  }

  @Override
  public int advanceShallow(int target) throws IOException {
    return in.advanceShallow(target);
  }

  @Override
  public void setMinCompetitiveScore(float minScore) throws IOException {
    in.setMinCompetitiveScore(minScore);
  }
}
/**
 * Immutable pair of a term and its boost.
 * <p>
 * Equality and hashing take both the term and the boost into account, whereas the
 * natural ordering looks at the term only.
 */
private static class TermAndBoost implements Comparable<TermAndBoost> {
  final Term term;
  final float boost;

  TermAndBoost(Term term, float boost) {
    this.boost = boost;
    this.term = term;
  }

  Term getTerm() {
    return term;
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null || o.getClass() != getClass()) {
      return false;
    }
    final TermAndBoost other = (TermAndBoost) o;
    if (Float.compare(other.boost, boost) != 0) {
      return false;
    }
    return Objects.equals(term, other.term);
  }

  @Override
  public int hashCode() {
    // same value as Objects.hash(term, boost), spelled out without the varargs array
    return 31 * (31 + Objects.hashCode(term)) + Float.hashCode(boost);
  }

  /** Orders entries by term only; the boost is not consulted. */
  @Override
  public int compareTo(TermAndBoost o) {
    return this.term.compareTo(o.term);
  }
}
}