| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.search; |
| |
| |
| import java.io.IOException; |
| import java.util.Arrays; |
| import java.util.List; |
| |
| import org.apache.lucene.index.IndexReader; |
| import org.apache.lucene.index.IndexReaderContext; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.Term; |
| import org.apache.lucene.index.TermState; |
| import org.apache.lucene.index.TermStates; |
| import org.apache.lucene.search.BooleanClause.Occur; |
| import org.apache.lucene.util.ArrayUtil; |
| import org.apache.lucene.util.InPlaceMergeSorter; |
| |
| /** |
| * A {@link Query} that blends index statistics across multiple terms. |
| * This is particularly useful when several terms should produce identical |
| * scores, regardless of their index statistics. |
| * <p>For instance, if you are resolving synonyms at search time, you may |
| * want all terms to produce identical scores instead of the default behavior, |
| * which tends to give higher scores to rare terms. |
| * <p>Another useful use case is cross-field search: imagine that you would |
| * like to search for {@code john} on two fields: {@code first_name} and |
| * {@code last_name}. You might not want to give a higher weight to matches |
| * on the field where {@code john} is rarer, in which case |
| * {@link BlendedTermQuery} would help as well. |
| * @lucene.experimental |
| */ |
| public final class BlendedTermQuery extends Query { |
| |
| /** A Builder for {@link BlendedTermQuery}. */ |
| public static class Builder { |
| |
| private int numTerms = 0; |
| private Term[] terms = new Term[0]; |
| private float[] boosts = new float[0]; |
| private TermStates[] contexts = new TermStates[0]; |
| private RewriteMethod rewriteMethod = DISJUNCTION_MAX_REWRITE; |
| |
| /** Sole constructor. */ |
| public Builder() {} |
| |
| /** Set the {@link RewriteMethod}. Default is to use |
| * {@link BlendedTermQuery#DISJUNCTION_MAX_REWRITE}. |
| * @see RewriteMethod */ |
| public Builder setRewriteMethod(RewriteMethod rewiteMethod) { |
| this.rewriteMethod = rewiteMethod; |
| return this; |
| } |
| |
| /** Add a new {@link Term} to this builder, with a default boost of {@code 1}. |
| * @see #add(Term, float) */ |
| public Builder add(Term term) { |
| return add(term, 1f); |
| } |
| |
| /** Add a {@link Term} with the provided boost. The higher the boost, the |
| * more this term will contribute to the overall score of the |
| * {@link BlendedTermQuery}. */ |
| public Builder add(Term term, float boost) { |
| return add(term, boost, null); |
| } |
| |
| /** |
| * Expert: Add a {@link Term} with the provided boost and context. |
| * This method is useful if you already have a {@link TermStates} |
| * object constructed for the given term. |
| */ |
| public Builder add(Term term, float boost, TermStates context) { |
| if (numTerms >= BooleanQuery.getMaxClauseCount()) { |
| throw new BooleanQuery.TooManyClauses(); |
| } |
| terms = ArrayUtil.grow(terms, numTerms + 1); |
| boosts = ArrayUtil.grow(boosts, numTerms + 1); |
| contexts = ArrayUtil.grow(contexts, numTerms + 1); |
| terms[numTerms] = term; |
| boosts[numTerms] = boost; |
| contexts[numTerms] = context; |
| numTerms += 1; |
| return this; |
| } |
| |
| /** Build the {@link BlendedTermQuery}. */ |
| public BlendedTermQuery build() { |
| return new BlendedTermQuery( |
| ArrayUtil.copyOfSubArray(terms, 0, numTerms), |
| ArrayUtil.copyOfSubArray(boosts, 0, numTerms), |
| ArrayUtil.copyOfSubArray(contexts, 0, numTerms), |
| rewriteMethod); |
| } |
| |
| } |
| |
| /** A {@link RewriteMethod} defines how queries for individual terms should |
| * be merged. |
| * @lucene.experimental |
| * @see BlendedTermQuery#BOOLEAN_REWRITE |
| * @see BlendedTermQuery.DisjunctionMaxRewrite */ |
| public static abstract class RewriteMethod { |
| |
| /** Sole constructor */ |
| protected RewriteMethod() {} |
| |
| /** Merge the provided sub queries into a single {@link Query} object. */ |
| public abstract Query rewrite(Query[] subQueries); |
| |
| } |
| |
| /** |
| * A {@link RewriteMethod} that adds all sub queries to a {@link BooleanQuery}. |
| * This {@link RewriteMethod} is useful when matching on several fields is |
| * considered better than having a good match on a single field. |
| */ |
| public static final RewriteMethod BOOLEAN_REWRITE = new RewriteMethod() { |
| @Override |
| public Query rewrite(Query[] subQueries) { |
| BooleanQuery.Builder merged = new BooleanQuery.Builder(); |
| for (Query query : subQueries) { |
| merged.add(query, Occur.SHOULD); |
| } |
| return merged.build(); |
| } |
| }; |
| |
| /** |
| * A {@link RewriteMethod} that creates a {@link DisjunctionMaxQuery} out |
| * of the sub queries. This {@link RewriteMethod} is useful when having a |
| * good match on a single field is considered better than having average |
| * matches on several fields. |
| */ |
| public static class DisjunctionMaxRewrite extends RewriteMethod { |
| |
| private final float tieBreakerMultiplier; |
| |
| /** This {@link RewriteMethod} will create {@link DisjunctionMaxQuery} |
| * instances that have the provided tie breaker. |
| * @see DisjunctionMaxQuery */ |
| public DisjunctionMaxRewrite(float tieBreakerMultiplier) { |
| this.tieBreakerMultiplier = tieBreakerMultiplier; |
| } |
| |
| @Override |
| public Query rewrite(Query[] subQueries) { |
| return new DisjunctionMaxQuery(Arrays.asList(subQueries), tieBreakerMultiplier); |
| } |
| |
| @Override |
| public boolean equals(Object obj) { |
| if (obj == null || getClass() != obj.getClass()) { |
| return false; |
| } |
| DisjunctionMaxRewrite that = (DisjunctionMaxRewrite) obj; |
| return tieBreakerMultiplier == that.tieBreakerMultiplier; |
| } |
| |
| @Override |
| public int hashCode() { |
| return 31 * getClass().hashCode() + Float.floatToIntBits(tieBreakerMultiplier); |
| } |
| |
| } |
| |
| /** {@link DisjunctionMaxRewrite} instance with a tie-breaker of {@code 0.01}. */ |
| public static final RewriteMethod DISJUNCTION_MAX_REWRITE = new DisjunctionMaxRewrite(0.01f); |
| |
| private final Term[] terms; |
| private final float[] boosts; |
| private final TermStates[] contexts; |
| private final RewriteMethod rewriteMethod; |
| |
| private BlendedTermQuery(Term[] terms, float[] boosts, TermStates[] contexts, |
| RewriteMethod rewriteMethod) { |
| assert terms.length == boosts.length; |
| assert terms.length == contexts.length; |
| this.terms = terms; |
| this.boosts = boosts; |
| this.contexts = contexts; |
| this.rewriteMethod = rewriteMethod; |
| |
| // we sort terms so that equals/hashcode does not rely on the order |
| new InPlaceMergeSorter() { |
| |
| @Override |
| protected void swap(int i, int j) { |
| Term tmpTerm = terms[i]; |
| terms[i] = terms[j]; |
| terms[j] = tmpTerm; |
| |
| TermStates tmpContext = contexts[i]; |
| contexts[i] = contexts[j]; |
| contexts[j] = tmpContext; |
| |
| float tmpBoost = boosts[i]; |
| boosts[i] = boosts[j]; |
| boosts[j] = tmpBoost; |
| } |
| |
| @Override |
| protected int compare(int i, int j) { |
| return terms[i].compareTo(terms[j]); |
| } |
| }.sort(0, terms.length); |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| return sameClassAs(other) && |
| equalsTo(getClass().cast(other)); |
| } |
| |
| private boolean equalsTo(BlendedTermQuery other) { |
| return Arrays.equals(terms, other.terms) && |
| Arrays.equals(contexts, other.contexts) && |
| Arrays.equals(boosts, other.boosts) && |
| rewriteMethod.equals(other.rewriteMethod); |
| } |
| |
| @Override |
| public int hashCode() { |
| int h = classHash(); |
| h = 31 * h + Arrays.hashCode(terms); |
| h = 31 * h + Arrays.hashCode(contexts); |
| h = 31 * h + Arrays.hashCode(boosts); |
| h = 31 * h + rewriteMethod.hashCode(); |
| return h; |
| } |
| |
| @Override |
| public String toString(String field) { |
| StringBuilder builder = new StringBuilder("Blended("); |
| for (int i = 0; i < terms.length; ++i) { |
| if (i != 0) { |
| builder.append(" "); |
| } |
| Query termQuery = new TermQuery(terms[i]); |
| if (boosts[i] != 1f) { |
| termQuery = new BoostQuery(termQuery, boosts[i]); |
| } |
| builder.append(termQuery.toString(field)); |
| } |
| builder.append(")"); |
| return builder.toString(); |
| } |
| |
| @Override |
| public final Query rewrite(IndexReader reader) throws IOException { |
| final TermStates[] contexts = ArrayUtil.copyOfSubArray(this.contexts, 0, this.contexts.length); |
| for (int i = 0; i < contexts.length; ++i) { |
| if (contexts[i] == null || contexts[i].wasBuiltFor(reader.getContext()) == false) { |
| contexts[i] = TermStates.build(reader.getContext(), terms[i], true); |
| } |
| } |
| |
| // Compute aggregated doc freq and total term freq |
| // df will be the max of all doc freqs |
| // ttf will be the sum of all total term freqs |
| int df = 0; |
| long ttf = 0; |
| for (TermStates ctx : contexts) { |
| df = Math.max(df, ctx.docFreq()); |
| ttf += ctx.totalTermFreq(); |
| } |
| |
| for (int i = 0; i < contexts.length; ++i) { |
| contexts[i] = adjustFrequencies(reader.getContext(), contexts[i], df, ttf); |
| } |
| |
| Query[] termQueries = new Query[terms.length]; |
| for (int i = 0; i < terms.length; ++i) { |
| termQueries[i] = new TermQuery(terms[i], contexts[i]); |
| if (boosts[i] != 1f) { |
| termQueries[i] = new BoostQuery(termQueries[i], boosts[i]); |
| } |
| } |
| return rewriteMethod.rewrite(termQueries); |
| } |
| |
| @Override |
| public void visit(QueryVisitor visitor) { |
| Term[] termsToVisit = Arrays.stream(terms).filter(t -> visitor.acceptField(t.field())).toArray(Term[]::new); |
| if (termsToVisit.length > 0) { |
| QueryVisitor v = visitor.getSubVisitor(Occur.SHOULD, this); |
| v.consumeTerms(this, termsToVisit); |
| } |
| } |
| |
| private static TermStates adjustFrequencies(IndexReaderContext readerContext, |
| TermStates ctx, int artificialDf, long artificialTtf) throws IOException { |
| List<LeafReaderContext> leaves = readerContext.leaves(); |
| final int len; |
| if (leaves == null) { |
| len = 1; |
| } else { |
| len = leaves.size(); |
| } |
| TermStates newCtx = new TermStates(readerContext); |
| for (int i = 0; i < len; ++i) { |
| TermState termState = ctx.get(leaves.get(i)); |
| if (termState == null) { |
| continue; |
| } |
| newCtx.register(termState, i); |
| } |
| newCtx.accumulateStatistics(artificialDf, artificialTtf); |
| return newCtx; |
| } |
| |
| } |