blob: 3400f7b551dad670dc114fcc24e26e77f22402e5 [file] [log] [blame]
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
index 1b8a4e5..378c2af 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
@@ -17,18 +17,19 @@
package org.apache.lucene.document;
import java.io.IOException;
+import java.util.Objects;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
@@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* <p>
* The constants in the above formulas typically need training in order to
* compute optimal values. If you don't know where to start, the
- * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
+ * {@link #newSaturationQuery(String, String)} method uses
* {@code 1f} as a weight and tries to guess a sensible value for the
* {@code pivot} parameter of the saturation function based on index
* statistics, which shouldn't perform too bad. Here is an example, assuming
@@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
* .build();
- * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ * Query boost = FeatureField.newSaturationQuery("features", "pagerank");
* Query boostedQuery = new BooleanQuery.Builder()
* .add(query, Occur.MUST)
* .add(boost, Occur.SHOULD)
@@ -210,6 +211,7 @@ public final class FeatureField extends Field {
static abstract class FeatureFunction {
abstract SimScorer scorer(String field, float w);
abstract Explanation explain(String field, String feature, float w, int freq);
+ FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
}
static final class LogFunction extends FeatureFunction {
@@ -263,24 +265,38 @@ public final class FeatureField extends Field {
static final class SaturationFunction extends FeatureFunction {
- private final float pivot;
+ private final String field, feature;
+ private final Float pivot;
- SaturationFunction(float pivot) {
+ SaturationFunction(String field, String feature, Float pivot) {
+ this.field = field;
+ this.feature = feature;
this.pivot = pivot;
}
@Override
+ public FeatureFunction rewrite(IndexReader reader) throws IOException {
+ if (pivot != null) {
+ return super.rewrite(reader);
+ }
+ float newPivot = computePivotFeatureValue(reader, field, feature);
+ return new SaturationFunction(field, feature, newPivot);
+ }
+
+ @Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
SaturationFunction that = (SaturationFunction) obj;
- return pivot == that.pivot;
+ return Objects.equals(field, that.field) &&
+ Objects.equals(feature, that.feature) &&
+ Objects.equals(pivot, that.pivot);
}
@Override
public int hashCode() {
- return Float.hashCode(pivot);
+ return Objects.hash(field, feature, pivot);
}
@Override
@@ -290,6 +306,10 @@ public final class FeatureField extends Field {
@Override
SimScorer scorer(String field, float weight) {
+ if (pivot == null) {
+ throw new IllegalStateException("Rewrite first");
+ }
+ final float pivot = this.pivot; // unbox
return new SimScorer(field) {
@Override
public float score(float freq, long norm) {
@@ -416,13 +436,30 @@ public final class FeatureField extends Field {
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
*/
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
+ return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
+ }
+
+ /**
+ * Same as {@link #newSaturationQuery(String, String, float, float)} but
+ * {@code 1f} is used as a weight and a reasonably good default pivot value
+ * is computed based on index statistics and is approximately equal to the
+ * geometric mean of all values that exist in the index.
+ * @param fieldName field that stores features
+ * @param featureName name of the feature
+ * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
+ */
+ public static Query newSaturationQuery(String fieldName, String featureName) {
+ return newSaturationQuery(fieldName, featureName, 1f, null);
+ }
+
+ private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
if (weight <= 0 || weight > MAX_WEIGHT) {
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
}
- if (pivot <= 0 || Float.isFinite(pivot) == false) {
+ if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
}
- Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
+ Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
if (weight != 1f) {
q = new BoostQuery(q, weight);
}
@@ -430,25 +467,6 @@ public final class FeatureField extends Field {
}
/**
- * Same as {@link #newSaturationQuery(String, String, float, float)} but
- * uses {@code 1f} as a weight and tries to compute a sensible default value
- * for {@code pivot} using
- * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
- * isn't expected to give an optimal configuration of these parameters but
- * should be a good start if you have no idea what the values of these
- * parameters should be.
- * @param searcher the {@link IndexSearcher} that you will search against
- * @param featureFieldName the field that stores features
- * @param featureName the name of the feature
- */
- public static Query newSaturationQuery(IndexSearcher searcher,
- String featureFieldName, String featureName) throws IOException {
- float weight = 1f;
- float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
- return newSaturationQuery(featureFieldName, featureName, weight, pivot);
- }
-
- /**
* Return a new {@link Query} that will score documents as
* {@code weight * S^a / (S^a + pivot^a)} where S is the value of the static feature.
* @param fieldName field that stores features
@@ -483,13 +501,20 @@ public final class FeatureField extends Field {
* representation in practice before converting it back to a float. Given that
* floats store the exponent in the higher bits, it means that the result will
* be an approximation of the geometric mean of all feature values.
- * @param searcher the {@link IndexSearcher} to search against
+ * @param reader the {@link IndexReader} to search against
* @param featureField the field that stores features
* @param featureName the name of the feature
*/
- public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
+ static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName);
- TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true);
+ TermStates states = TermStates.build(reader.getContext(), term, true);
+ if (states.docFreq() == 0) {
+ // avoid division by 0
+ // The return value doesn't matter much here, the term doesn't exist,
+ // it will never be used for scoring. Just Make sure to return a legal
+ // value.
+ return 1;
+ }
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
return decodeFeatureValue(avgFreq);
}
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
index 2b38712..add1b4a 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
@@ -22,6 +22,7 @@ import java.util.Set;
import org.apache.lucene.document.FeatureField.FeatureFunction;
import org.apache.lucene.index.ImpactsEnum;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
@@ -51,6 +52,15 @@ final class FeatureQuery extends Query {
}
@Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ FeatureFunction rewritten = function.rewrite(reader);
+ if (function != rewritten) {
+ return new FeatureQuery(fieldName, featureName, rewritten);
+ }
+ return super.rewrite(reader);
+ }
+
+ @Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
@@ -80,7 +90,16 @@ final class FeatureQuery extends Query {
}
@Override
- public void extractTerms(Set<Term> terms) {}
+ public void extractTerms(Set<Term> terms) {
+ if (scoreMode.needsScores() == false) {
+ // features are irrelevant to highlighting, skip
+ } else {
+ // extracting the term here will help get better scoring with
+ // distributed term statistics if the saturation function is used
+ // and the pivot value is computed automatically
+ terms.add(new Term(fieldName, featureName));
+ }
+ }
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
index 2afc250..312abdc 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
@@ -17,10 +17,15 @@
package org.apache.lucene.document;
import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
@@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase {
}
public void testSatuSimScorer() {
- doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f));
+ doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
}
public void testSigmSimScorer() {
@@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase {
public void testComputePivotFeatureValue() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
+
+ // Make sure that we create a legal pivot on missing features
+ DirectoryReader reader = writer.getReader();
+ float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
+ assertTrue(Float.isFinite(pivot));
+ assertTrue(pivot > 0);
+ reader.close();
+
Document doc = new Document();
FeatureField pagerank = new FeatureField("features", "pagerank", 1);
doc.add(pagerank);
@@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase {
pagerank.setFeatureValue(42);
writer.addDocument(doc);
- DirectoryReader reader = writer.getReader();
+ reader = writer.getReader();
writer.close();
- IndexSearcher searcher = new IndexSearcher(reader);
- float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
+ pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
assertEquals(expected, pivot, 0.1);
@@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase {
dir.close();
}
+ public void testExtractTerms() throws IOException {
+ IndexReader reader = new MultiReader();
+ IndexSearcher searcher = newSearcher(reader);
+ Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
+
+ Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f);
+ Set<Term> terms = new HashSet<>();
+ weight.extractTerms(terms);
+ assertEquals(Collections.emptySet(), terms);
+
+ terms = new HashSet<>();
+ weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f);
+ weight.extractTerms(terms);
+ assertEquals(Collections.singleton(new Term("field", "term")), terms);
+
+ terms = new HashSet<>();
+ weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f);
+ weight.extractTerms(terms);
+ assertEquals(Collections.singleton(new Term("field", "term")), terms);
+ }
+
public void testDemo() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
@@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase {
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
.build();
- Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ Query boost = FeatureField.newSaturationQuery("features", "pagerank");
Query boostedQuery = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(boost, Occur.SHOULD)