| diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java |
| index 1b8a4e5..378c2af 100644 |
| --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java |
| +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java |
| @@ -17,18 +17,19 @@ |
| package org.apache.lucene.document; |
| |
| import java.io.IOException; |
| +import java.util.Objects; |
| |
| import org.apache.lucene.analysis.Analyzer; |
| import org.apache.lucene.analysis.TokenStream; |
| import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; |
| import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; |
| import org.apache.lucene.index.IndexOptions; |
| +import org.apache.lucene.index.IndexReader; |
| import org.apache.lucene.index.Term; |
| import org.apache.lucene.index.TermStates; |
| import org.apache.lucene.search.BooleanQuery; |
| import org.apache.lucene.search.BoostQuery; |
| import org.apache.lucene.search.Explanation; |
| -import org.apache.lucene.search.IndexSearcher; |
| import org.apache.lucene.search.Query; |
| import org.apache.lucene.search.similarities.BM25Similarity; |
| import org.apache.lucene.search.similarities.Similarity.SimScorer; |
| @@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; |
| * <p> |
| * The constants in the above formulas typically need training in order to |
| * compute optimal values. If you don't know where to start, the |
| - * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses |
| + * {@link #newSaturationQuery(String, String)} method uses |
| * {@code 1f} as a weight and tries to guess a sensible value for the |
| * {@code pivot} parameter of the saturation function based on index |
| * statistics, which shouldn't perform too bad. Here is an example, assuming |
| @@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; |
| * .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) |
| * .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) |
| * .build(); |
| - * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); |
| + * Query boost = FeatureField.newSaturationQuery("features", "pagerank"); |
| * Query boostedQuery = new BooleanQuery.Builder() |
| * .add(query, Occur.MUST) |
| * .add(boost, Occur.SHOULD) |
| @@ -210,6 +211,7 @@ public final class FeatureField extends Field { |
| static abstract class FeatureFunction { |
| abstract SimScorer scorer(String field, float w); |
| abstract Explanation explain(String field, String feature, float w, int freq); |
| + FeatureFunction rewrite(IndexReader reader) throws IOException { return this; } |
| } |
| |
| static final class LogFunction extends FeatureFunction { |
| @@ -263,24 +265,38 @@ public final class FeatureField extends Field { |
| |
| static final class SaturationFunction extends FeatureFunction { |
| |
| - private final float pivot; |
| + private final String field, feature; |
| + private final Float pivot; |
| |
| - SaturationFunction(float pivot) { |
| + SaturationFunction(String field, String feature, Float pivot) { |
| + this.field = field; |
| + this.feature = feature; |
| this.pivot = pivot; |
| } |
| |
| @Override |
| + public FeatureFunction rewrite(IndexReader reader) throws IOException { |
| + if (pivot != null) { |
| + return super.rewrite(reader); |
| + } |
| + float newPivot = computePivotFeatureValue(reader, field, feature); |
| + return new SaturationFunction(field, feature, newPivot); |
| + } |
| + |
| + @Override |
| public boolean equals(Object obj) { |
| if (obj == null || getClass() != obj.getClass()) { |
| return false; |
| } |
| SaturationFunction that = (SaturationFunction) obj; |
| - return pivot == that.pivot; |
| + return Objects.equals(field, that.field) && |
| + Objects.equals(feature, that.feature) && |
| + Objects.equals(pivot, that.pivot); |
| } |
| |
| @Override |
| public int hashCode() { |
| - return Float.hashCode(pivot); |
| + return Objects.hash(field, feature, pivot); |
| } |
| |
| @Override |
| @@ -290,6 +306,10 @@ public final class FeatureField extends Field { |
| |
| @Override |
| SimScorer scorer(String field, float weight) { |
| + if (pivot == null) { |
| + throw new IllegalStateException("Rewrite first"); |
| + } |
| + final float pivot = this.pivot; // unbox |
| return new SimScorer(field) { |
| @Override |
| public float score(float freq, long norm) { |
| @@ -416,13 +436,30 @@ public final class FeatureField extends Field { |
| * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity) |
| */ |
| public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) { |
| + return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot)); |
| + } |
| + |
| + /** |
| + * Same as {@link #newSaturationQuery(String, String, float, float)} but |
| + * {@code 1f} is used as a weight and a reasonably good default pivot value |
| + * is computed based on index statistics and is approximately equal to the |
| + * geometric mean of all values that exist in the index. |
| + * @param fieldName field that stores features |
| + * @param featureName name of the feature |
| + * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity) |
| + */ |
| + public static Query newSaturationQuery(String fieldName, String featureName) { |
| + return newSaturationQuery(fieldName, featureName, 1f, null); |
| + } |
| + |
| + private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) { |
| if (weight <= 0 || weight > MAX_WEIGHT) { |
| throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight); |
| } |
| - if (pivot <= 0 || Float.isFinite(pivot) == false) { |
| + if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) { |
| throw new IllegalArgumentException("pivot must be > 0, got: " + pivot); |
| } |
| - Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot)); |
| + Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot)); |
| if (weight != 1f) { |
| q = new BoostQuery(q, weight); |
| } |
| @@ -430,25 +467,6 @@ public final class FeatureField extends Field { |
| } |
| |
| /** |
| - * Same as {@link #newSaturationQuery(String, String, float, float)} but |
| - * uses {@code 1f} as a weight and tries to compute a sensible default value |
| - * for {@code pivot} using |
| - * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This |
| - * isn't expected to give an optimal configuration of these parameters but |
| - * should be a good start if you have no idea what the values of these |
| - * parameters should be. |
| - * @param searcher the {@link IndexSearcher} that you will search against |
| - * @param featureFieldName the field that stores features |
| - * @param featureName the name of the feature |
| - */ |
| - public static Query newSaturationQuery(IndexSearcher searcher, |
| - String featureFieldName, String featureName) throws IOException { |
| - float weight = 1f; |
| - float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName); |
| - return newSaturationQuery(featureFieldName, featureName, weight, pivot); |
| - } |
| - |
| - /** |
| * Return a new {@link Query} that will score documents as |
| * {@code weight * S^a / (S^a + pivot^a)} where S is the value of the static feature. |
| * @param fieldName field that stores features |
| @@ -483,13 +501,20 @@ public final class FeatureField extends Field { |
| * representation in practice before converting it back to a float. Given that |
| * floats store the exponent in the higher bits, it means that the result will |
| * be an approximation of the geometric mean of all feature values. |
| - * @param searcher the {@link IndexSearcher} to search against |
| + * @param reader the {@link IndexReader} to search against |
| * @param featureField the field that stores features |
| * @param featureName the name of the feature |
| */ |
| - public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException { |
| + static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException { |
| Term term = new Term(featureField, featureName); |
| - TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true); |
| + TermStates states = TermStates.build(reader.getContext(), term, true); |
| + if (states.docFreq() == 0) { |
| + // avoid division by 0 |
| + // The return value doesn't matter much here, the term doesn't exist, |
| + // it will never be used for scoring. Just Make sure to return a legal |
| + // value. |
| + return 1; |
| + } |
| float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq()); |
| return decodeFeatureValue(avgFreq); |
| } |
| diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java |
| index 2b38712..add1b4a 100644 |
| --- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java |
| +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java |
| @@ -22,6 +22,7 @@ import java.util.Set; |
| |
| import org.apache.lucene.document.FeatureField.FeatureFunction; |
| import org.apache.lucene.index.ImpactsEnum; |
| +import org.apache.lucene.index.IndexReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| import org.apache.lucene.index.PostingsEnum; |
| import org.apache.lucene.index.Term; |
| @@ -51,6 +52,15 @@ final class FeatureQuery extends Query { |
| } |
| |
| @Override |
| + public Query rewrite(IndexReader reader) throws IOException { |
| + FeatureFunction rewritten = function.rewrite(reader); |
| + if (function != rewritten) { |
| + return new FeatureQuery(fieldName, featureName, rewritten); |
| + } |
| + return super.rewrite(reader); |
| + } |
| + |
| + @Override |
| public boolean equals(Object obj) { |
| if (obj == null || getClass() != obj.getClass()) { |
| return false; |
| @@ -80,7 +90,16 @@ final class FeatureQuery extends Query { |
| } |
| |
| @Override |
| - public void extractTerms(Set<Term> terms) {} |
| + public void extractTerms(Set<Term> terms) { |
| + if (scoreMode.needsScores() == false) { |
| + // features are irrelevant to highlighting, skip |
| + } else { |
| + // extracting the term here will help get better scoring with |
| + // distributed term statistics if the saturation function is used |
| + // and the pivot value is computed automatically |
| + terms.add(new Term(fieldName, featureName)); |
| + } |
| + } |
| |
| @Override |
| public Explanation explain(LeafReaderContext context, int doc) throws IOException { |
| diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java |
| index 2afc250..312abdc 100644 |
| --- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java |
| +++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java |
| @@ -17,10 +17,15 @@ |
| package org.apache.lucene.document; |
| |
| import java.io.IOException; |
| +import java.util.Collections; |
| +import java.util.HashSet; |
| +import java.util.Set; |
| |
| import org.apache.lucene.document.Field.Store; |
| import org.apache.lucene.index.DirectoryReader; |
| +import org.apache.lucene.index.IndexReader; |
| import org.apache.lucene.index.LeafReaderContext; |
| +import org.apache.lucene.index.MultiReader; |
| import org.apache.lucene.index.RandomIndexWriter; |
| import org.apache.lucene.index.Term; |
| import org.apache.lucene.search.BooleanClause.Occur; |
| @@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase { |
| } |
| |
| public void testSatuSimScorer() { |
| - doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f)); |
| + doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f)); |
| } |
| |
| public void testSigmSimScorer() { |
| @@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase { |
| public void testComputePivotFeatureValue() throws IOException { |
| Directory dir = newDirectory(); |
| RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); |
| + |
| + // Make sure that we create a legal pivot on missing features |
| + DirectoryReader reader = writer.getReader(); |
| + float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); |
| + assertTrue(Float.isFinite(pivot)); |
| + assertTrue(pivot > 0); |
| + reader.close(); |
| + |
| Document doc = new Document(); |
| FeatureField pagerank = new FeatureField("features", "pagerank", 1); |
| doc.add(pagerank); |
| @@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase { |
| pagerank.setFeatureValue(42); |
| writer.addDocument(doc); |
| |
| - DirectoryReader reader = writer.getReader(); |
| + reader = writer.getReader(); |
| writer.close(); |
| |
| - IndexSearcher searcher = new IndexSearcher(reader); |
| - float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank"); |
| + pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); |
| double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean |
| assertEquals(expected, pivot, 0.1); |
| |
| @@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase { |
| dir.close(); |
| } |
| |
| + public void testExtractTerms() throws IOException { |
| + IndexReader reader = new MultiReader(); |
| + IndexSearcher searcher = newSearcher(reader); |
| + Query query = FeatureField.newLogQuery("field", "term", 2f, 42); |
| + |
| + Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f); |
| + Set<Term> terms = new HashSet<>(); |
| + weight.extractTerms(terms); |
| + assertEquals(Collections.emptySet(), terms); |
| + |
| + terms = new HashSet<>(); |
| + weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f); |
| + weight.extractTerms(terms); |
| + assertEquals(Collections.singleton(new Term("field", "term")), terms); |
| + |
| + terms = new HashSet<>(); |
| + weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f); |
| + weight.extractTerms(terms); |
| + assertEquals(Collections.singleton(new Term("field", "term")), terms); |
| + } |
| + |
| public void testDemo() throws IOException { |
| Directory dir = newDirectory(); |
| RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig() |
| @@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase { |
| .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) |
| .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) |
| .build(); |
| - Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); |
| + Query boost = FeatureField.newSaturationQuery("features", "pagerank"); |
| Query boostedQuery = new BooleanQuery.Builder() |
| .add(query, Occur.MUST) |
| .add(boost, Occur.SHOULD) |