| diff --git a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java |
| index d4be4e9..af24e1a 100644 |
| --- a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java |
| +++ b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java |
| @@ -20,7 +20,9 @@ package org.apache.lucene.search; |
| import java.io.IOException; |
| import java.util.Objects; |
| import java.util.function.DoubleToLongFunction; |
| +import java.util.function.DoubleUnaryOperator; |
| import java.util.function.LongToDoubleFunction; |
| +import java.util.function.ToDoubleBiFunction; |
| |
| import org.apache.lucene.index.DocValues; |
| import org.apache.lucene.index.LeafReaderContext; |
| @@ -173,6 +175,69 @@ public abstract class DoubleValuesSource { |
| public boolean needsScores() { |
| return false; |
| } |
| + |
| + @Override |
| + public String toString() { |
| + return "constant(" + value + ")"; |
| + } |
| + }; |
| + } |
| + |
| + /** |
| + * Creates a DoubleValuesSource that is a function of another DoubleValuesSource |
| + */ |
| + public static DoubleValuesSource function(DoubleValuesSource in, DoubleUnaryOperator function) { |
| + return new DoubleValuesSource() { |
| + @Override |
| + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { |
| + DoubleValues inputs = in.getValues(ctx, scores); |
| + return new DoubleValues() { |
| + @Override |
| + public double doubleValue() throws IOException { |
| + return function.applyAsDouble(inputs.doubleValue()); |
| + } |
| + |
| + @Override |
| + public boolean advanceExact(int doc) throws IOException { |
| + return inputs.advanceExact(doc); |
| + } |
| + }; |
| + } |
| + |
| + @Override |
| + public boolean needsScores() { |
| + return in.needsScores(); |
| + } |
| + }; |
| + } |
| + |
| + /** |
| + * Creates a DoubleValuesSource that is a function of another DoubleValuesSource and a score |
| + * @param in the DoubleValuesSource to use as an input |
| + * @param function a function of the form (source, score) == result |
| + */ |
| + public static DoubleValuesSource scoringFunction(DoubleValuesSource in, ToDoubleBiFunction<Double, Double> function) { |
| + return new DoubleValuesSource() { |
| + @Override |
| + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { |
| + DoubleValues inputs = in.getValues(ctx, scores); |
| + return new DoubleValues() { |
| + @Override |
| + public double doubleValue() throws IOException { |
| + return function.applyAsDouble(inputs.doubleValue(), scores.doubleValue()); |
| + } |
| + |
| + @Override |
| + public boolean advanceExact(int doc) throws IOException { |
| + return inputs.advanceExact(doc); |
| + } |
| + }; |
| + } |
| + |
| + @Override |
| + public boolean needsScores() { |
| + return true; |
| + } |
| }; |
| } |
| |
| @@ -221,7 +286,17 @@ public abstract class DoubleValuesSource { |
| @Override |
| public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { |
| final NumericDocValues values = DocValues.getNumeric(ctx.reader(), field); |
| - return toDoubleValues(values, decoder::applyAsDouble); |
| + return new DoubleValues() { |
| + @Override |
| + public double doubleValue() throws IOException { |
| + return decoder.applyAsDouble(values.longValue()); |
| + } |
| + |
| + @Override |
| + public boolean advanceExact(int target) throws IOException { |
| + return values.advanceExact(target); |
| + } |
| + }; |
| } |
| |
| @Override |
| @@ -288,21 +363,6 @@ public abstract class DoubleValuesSource { |
| } |
| } |
| |
| - private static DoubleValues toDoubleValues(NumericDocValues in, LongToDoubleFunction map) { |
| - return new DoubleValues() { |
| - @Override |
| - public double doubleValue() throws IOException { |
| - return map.applyAsDouble(in.longValue()); |
| - } |
| - |
| - @Override |
| - public boolean advanceExact(int target) throws IOException { |
| - return in.advanceExact(target); |
| - } |
| - |
| - }; |
| - } |
| - |
| private static NumericDocValues asNumericDocValues(DoubleValuesHolder in, DoubleToLongFunction converter) { |
| return new NumericDocValues() { |
| @Override |
| diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionMatchQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionMatchQuery.java |
| new file mode 100644 |
| index 0000000..4a9c709 |
| --- /dev/null |
| +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionMatchQuery.java |
| @@ -0,0 +1,99 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.lucene.queries.function; |
| + |
| +import java.io.IOException; |
| +import java.util.Objects; |
| +import java.util.function.DoublePredicate; |
| + |
| +import org.apache.lucene.index.LeafReaderContext; |
| +import org.apache.lucene.search.ConstantScoreScorer; |
| +import org.apache.lucene.search.ConstantScoreWeight; |
| +import org.apache.lucene.search.DocIdSetIterator; |
| +import org.apache.lucene.search.DoubleValues; |
| +import org.apache.lucene.search.DoubleValuesSource; |
| +import org.apache.lucene.search.IndexSearcher; |
| +import org.apache.lucene.search.Query; |
| +import org.apache.lucene.search.Scorer; |
| +import org.apache.lucene.search.TwoPhaseIterator; |
| +import org.apache.lucene.search.Weight; |
| + |
| +/** |
| + * A query that retrieves all documents with a {@link DoubleValues} value matching a predicate |
| + * |
| + * This query works by a linear scan of the index, and is best used in |
| + * conjunction with other queries that can restrict the number of |
| + * documents visited |
| + */ |
| +public final class FunctionMatchQuery extends Query { |
| + |
| + private final DoubleValuesSource source; |
| + private final DoublePredicate filter; |
| + |
| + /** |
| + * Create a FunctionMatchQuery |
| + * @param source a {@link DoubleValuesSource} to use for values |
| + * @param filter the predicate to match against |
| + */ |
| + public FunctionMatchQuery(DoubleValuesSource source, DoublePredicate filter) { |
| + this.source = source; |
| + this.filter = filter; |
| + } |
| + |
| + @Override |
| + public String toString(String field) { |
| + return "FunctionMatchQuery(" + source.toString() + ")"; |
| + } |
| + |
| + @Override |
| + public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException { |
| + return new ConstantScoreWeight(this, boost) { |
| + @Override |
| + public Scorer scorer(LeafReaderContext context) throws IOException { |
| + DoubleValues values = source.getValues(context, null); |
| + DocIdSetIterator approximation = DocIdSetIterator.all(context.reader().maxDoc()); |
| + TwoPhaseIterator twoPhase = new TwoPhaseIterator(approximation) { |
| + @Override |
| + public boolean matches() throws IOException { |
| + return values.advanceExact(approximation.docID()) && filter.test(values.doubleValue()); |
| + } |
| + |
| + @Override |
| + public float matchCost() { |
| + return 100; // TODO maybe DoubleValuesSource should have a matchCost? |
| + } |
| + }; |
| + return new ConstantScoreScorer(this, score(), twoPhase); |
| + } |
| + }; |
| + } |
| + |
| + @Override |
| + public boolean equals(Object o) { |
| + if (this == o) return true; |
| + if (o == null || getClass() != o.getClass()) return false; |
| + FunctionMatchQuery that = (FunctionMatchQuery) o; |
| + return Objects.equals(source, that.source) && Objects.equals(filter, that.filter); |
| + } |
| + |
| + @Override |
| + public int hashCode() { |
| + return Objects.hash(source, filter); |
| + } |
| + |
| +} |
| diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java |
| new file mode 100644 |
| index 0000000..29ef41f |
| --- /dev/null |
| +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java |
| @@ -0,0 +1,151 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.lucene.queries.function; |
| + |
| +import java.io.IOException; |
| +import java.util.Objects; |
| +import java.util.Set; |
| + |
| +import org.apache.lucene.index.IndexReader; |
| +import org.apache.lucene.index.LeafReaderContext; |
| +import org.apache.lucene.index.Term; |
| +import org.apache.lucene.search.DoubleValues; |
| +import org.apache.lucene.search.DoubleValuesSource; |
| +import org.apache.lucene.search.Explanation; |
| +import org.apache.lucene.search.FilterScorer; |
| +import org.apache.lucene.search.IndexSearcher; |
| +import org.apache.lucene.search.Query; |
| +import org.apache.lucene.search.Scorer; |
| +import org.apache.lucene.search.Weight; |
| + |
| +/** |
| + * A query that wraps another query, and uses a DoubleValuesSource to |
| + * replace or modify the wrapped query's score |
| + * |
| + * If the DoubleValuesSource doesn't return a value for a particular document, |
| + * then that document will be given a score of 0. |
| + */ |
| +public final class FunctionScoreQuery extends Query { |
| + |
| + private final Query in; |
| + private final DoubleValuesSource source; |
| + |
| + /** |
| + * Create a new FunctionScoreQuery |
| + * @param in the query to wrap |
| + * @param source a source of scores |
| + */ |
| + public FunctionScoreQuery(Query in, DoubleValuesSource source) { |
| + this.in = in; |
| + this.source = source; |
| + } |
| + |
| + @Override |
| + public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException { |
| + Weight inner = in.createWeight(searcher, needsScores && source.needsScores(), 1f); |
| + if (needsScores == false) |
| + return inner; |
| + return new FunctionScoreWeight(this, inner, source, boost); |
| + } |
| + |
| + @Override |
| + public Query rewrite(IndexReader reader) throws IOException { |
| + Query rewritten = in.rewrite(reader); |
| + if (rewritten == in) |
| + return this; |
| + return new FunctionScoreQuery(rewritten, source); |
| + } |
| + |
| + @Override |
| + public String toString(String field) { |
| + return "FunctionScoreQuery(" + in.toString(field) + ", scored by " + source.toString() + ")"; |
| + } |
| + |
| + @Override |
| + public boolean equals(Object o) { |
| + if (this == o) return true; |
| + if (o == null || getClass() != o.getClass()) return false; |
| + FunctionScoreQuery that = (FunctionScoreQuery) o; |
| + return Objects.equals(in, that.in) && |
| + Objects.equals(source, that.source); |
| + } |
| + |
| + @Override |
| + public int hashCode() { |
| + return Objects.hash(in, source); |
| + } |
| + |
| + private static class FunctionScoreWeight extends Weight { |
| + |
| + final Weight inner; |
| + final DoubleValuesSource valueSource; |
| + final float boost; |
| + |
| + FunctionScoreWeight(Query query, Weight inner, DoubleValuesSource valueSource, float boost) { |
| + super(query); |
| + this.inner = inner; |
| + this.valueSource = valueSource; |
| + this.boost = boost; |
| + } |
| + |
| + @Override |
| + public void extractTerms(Set<Term> terms) { |
| + this.inner.extractTerms(terms); |
| + } |
| + |
| + @Override |
| + public Explanation explain(LeafReaderContext context, int doc) throws IOException { |
| + Scorer scorer = inner.scorer(context); |
| + if (scorer.iterator().advance(doc) != doc) |
| + return Explanation.noMatch("No match"); |
| + DoubleValues scores = valueSource.getValues(context, DoubleValuesSource.fromScorer(scorer)); |
| + scores.advanceExact(doc); |
| + Explanation scoreExpl = scoreExplanation(context, doc, scores); |
| + if (boost == 1f) |
| + return scoreExpl; |
| + return Explanation.match(scoreExpl.getValue() * boost, "product of:", |
| + Explanation.match(boost, "boost"), scoreExpl); |
| + } |
| + |
| + private Explanation scoreExplanation(LeafReaderContext context, int doc, DoubleValues scores) throws IOException { |
| + if (valueSource.needsScores() == false) |
| + return Explanation.match((float) scores.doubleValue(), valueSource.toString()); |
| + float score = (float) scores.doubleValue(); |
| + return Explanation.match(score, "computed from:", |
| + Explanation.match(score, valueSource.toString()), |
| + inner.explain(context, doc)); |
| + } |
| + |
| + @Override |
| + public Scorer scorer(LeafReaderContext context) throws IOException { |
| + Scorer in = inner.scorer(context); |
| + if (in == null) |
| + return null; |
| + DoubleValues scores = valueSource.getValues(context, DoubleValuesSource.fromScorer(in)); |
| + return new FilterScorer(in) { |
| + @Override |
| + public float score() throws IOException { |
| + if (scores.advanceExact(docID())) |
| + return (float) (scores.doubleValue() * boost); |
| + else |
| + return 0; |
| + } |
| + }; |
| + } |
| + } |
| +} |
| diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionMatchQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionMatchQuery.java |
| new file mode 100644 |
| index 0000000..61faa15 |
| --- /dev/null |
| +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionMatchQuery.java |
| @@ -0,0 +1,61 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.lucene.queries.function; |
| + |
| +import java.io.IOException; |
| + |
| +import org.apache.lucene.index.DirectoryReader; |
| +import org.apache.lucene.index.IndexReader; |
| +import org.apache.lucene.search.DoubleValuesSource; |
| +import org.apache.lucene.search.IndexSearcher; |
| +import org.apache.lucene.search.QueryUtils; |
| +import org.apache.lucene.search.TopDocs; |
| +import org.junit.AfterClass; |
| +import org.junit.BeforeClass; |
| + |
| +public class TestFunctionMatchQuery extends FunctionTestSetup { |
| + |
| + static IndexReader reader; |
| + static IndexSearcher searcher; |
| + |
| + @BeforeClass |
| + public static void beforeClass() throws Exception { |
| + createIndex(true); |
| + reader = DirectoryReader.open(dir); |
| + searcher = new IndexSearcher(reader); |
| + } |
| + |
| + @AfterClass |
| + public static void afterClass() throws Exception { |
| + reader.close(); |
| + } |
| + |
| + public void testRangeMatching() throws IOException { |
| + DoubleValuesSource in = DoubleValuesSource.fromFloatField(FLOAT_FIELD); |
| + FunctionMatchQuery fmq = new FunctionMatchQuery(in, d -> d >= 2 && d < 4); |
| + TopDocs docs = searcher.search(fmq, 10); |
| + |
| + assertEquals(2, docs.totalHits); |
| + assertEquals(9, docs.scoreDocs[0].doc); |
| + assertEquals(13, docs.scoreDocs[1].doc); |
| + |
| + QueryUtils.check(random(), fmq, searcher, rarely()); |
| + |
| + } |
| + |
| +} |
| diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreExplanations.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreExplanations.java |
| new file mode 100644 |
| index 0000000..5c64396 |
| --- /dev/null |
| +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreExplanations.java |
| @@ -0,0 +1,105 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.lucene.queries.function; |
| + |
| +import java.io.IOException; |
| + |
| +import org.apache.lucene.index.Term; |
| +import org.apache.lucene.search.BaseExplanationTestCase; |
| +import org.apache.lucene.search.BooleanClause; |
| +import org.apache.lucene.search.BooleanQuery; |
| +import org.apache.lucene.search.BoostQuery; |
| +import org.apache.lucene.search.DoubleValuesSource; |
| +import org.apache.lucene.search.Explanation; |
| +import org.apache.lucene.search.IndexSearcher; |
| +import org.apache.lucene.search.MatchAllDocsQuery; |
| +import org.apache.lucene.search.Query; |
| +import org.apache.lucene.search.TermQuery; |
| +import org.apache.lucene.search.similarities.BM25Similarity; |
| +import org.apache.lucene.search.similarities.ClassicSimilarity; |
| + |
| +public class TestFunctionScoreExplanations extends BaseExplanationTestCase { |
| + |
| + public void testOneTerm() throws Exception { |
| + Query q = new TermQuery(new Term(FIELD, "w1")); |
| + FunctionScoreQuery fsq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5)); |
| + qtest(fsq, new int[] { 0,1,2,3 }); |
| + } |
| + |
| + public void testBoost() throws Exception { |
| + Query q = new TermQuery(new Term(FIELD, "w1")); |
| + FunctionScoreQuery csq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5)); |
| + qtest(new BoostQuery(csq, 4), new int[] { 0,1,2,3 }); |
| + } |
| + |
| + public void testTopLevelBoost() throws Exception { |
| + Query q = new TermQuery(new Term(FIELD, "w1")); |
| + FunctionScoreQuery csq = new FunctionScoreQuery(q, DoubleValuesSource.constant(5)); |
| + BooleanQuery.Builder bqB = new BooleanQuery.Builder(); |
| + bqB.add(new MatchAllDocsQuery(), BooleanClause.Occur.MUST); |
| + bqB.add(csq, BooleanClause.Occur.MUST); |
| + BooleanQuery bq = bqB.build(); |
| + qtest(new BoostQuery(bq, 6), new int[] { 0,1,2,3 }); |
| + } |
| + |
| + public void testExplanationsIncludingScore() throws Exception { |
| + |
| + DoubleValuesSource scores = DoubleValuesSource.function(DoubleValuesSource.SCORES, v -> v * 2); |
| + |
| + Query q = new TermQuery(new Term(FIELD, "w1")); |
| + FunctionScoreQuery csq = new FunctionScoreQuery(q, scores); |
| + |
| + qtest(csq, new int[] { 0, 1, 2, 3 }); |
| + |
| + Explanation e1 = searcher.explain(q, 0); |
| + Explanation e = searcher.explain(csq, 0); |
| + |
| + assertEquals(e.getDetails().length, 2); |
| + |
| + assertEquals(e1.getValue() * 2, e.getValue(), 0.00001); |
| + } |
| + |
| + public void testSubExplanations() throws IOException { |
| + Query query = new FunctionScoreQuery(new MatchAllDocsQuery(), DoubleValuesSource.constant(5)); |
| + IndexSearcher searcher = newSearcher(BaseExplanationTestCase.searcher.getIndexReader()); |
| + searcher.setSimilarity(new BM25Similarity()); |
| + |
| + Explanation expl = searcher.explain(query, 0); |
| + assertEquals("constant(5.0)", expl.getDescription()); |
| + assertEquals(0, expl.getDetails().length); |
| + |
| + query = new BoostQuery(query, 2); |
| + expl = searcher.explain(query, 0); |
| + assertEquals(2, expl.getDetails().length); |
| + // function |
| + assertEquals(5f, expl.getDetails()[1].getValue(), 0f); |
| + // boost |
| + assertEquals("boost", expl.getDetails()[0].getDescription()); |
| + assertEquals(2f, expl.getDetails()[0].getValue(), 0f); |
| + |
| + searcher.setSimilarity(new ClassicSimilarity()); // in order to have a queryNorm != 1 |
| + expl = searcher.explain(query, 0); |
| + assertEquals(2, expl.getDetails().length); |
| + // function |
| + assertEquals(5f, expl.getDetails()[1].getValue(), 0f); |
| + // boost |
| + assertEquals("boost", expl.getDetails()[0].getDescription()); |
| + assertEquals(2f, expl.getDetails()[0].getValue(), 0f); |
| + } |
| + |
| +} |
| diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java |
| new file mode 100644 |
| index 0000000..8f6ef8e |
| --- /dev/null |
| +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java |
| @@ -0,0 +1,114 @@ |
| +/* |
| + * Licensed to the Apache Software Foundation (ASF) under one or more |
| + * contributor license agreements. See the NOTICE file distributed with |
| + * this work for additional information regarding copyright ownership. |
| + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| + * (the "License"); you may not use this file except in compliance with |
| + * the License. You may obtain a copy of the License at |
| + * |
| + * http://www.apache.org/licenses/LICENSE-2.0 |
| + * |
| + * Unless required by applicable law or agreed to in writing, software |
| + * distributed under the License is distributed on an "AS IS" BASIS, |
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| + * See the License for the specific language governing permissions and |
| + * limitations under the License. |
| + */ |
| + |
| +package org.apache.lucene.queries.function; |
| + |
| +import org.apache.lucene.index.DirectoryReader; |
| +import org.apache.lucene.index.IndexReader; |
| +import org.apache.lucene.index.Term; |
| +import org.apache.lucene.search.BooleanClause; |
| +import org.apache.lucene.search.BooleanQuery; |
| +import org.apache.lucene.search.BoostQuery; |
| +import org.apache.lucene.search.DoubleValuesSource; |
| +import org.apache.lucene.search.IndexSearcher; |
| +import org.apache.lucene.search.Query; |
| +import org.apache.lucene.search.QueryUtils; |
| +import org.apache.lucene.search.TermQuery; |
| +import org.apache.lucene.search.TopDocs; |
| +import org.junit.AfterClass; |
| +import org.junit.BeforeClass; |
| + |
| +public class TestFunctionScoreQuery extends FunctionTestSetup { |
| + |
| + static IndexReader reader; |
| + static IndexSearcher searcher; |
| + |
| + @BeforeClass |
| + public static void beforeClass() throws Exception { |
| + createIndex(true); |
| + reader = DirectoryReader.open(dir); |
| + searcher = new IndexSearcher(reader); |
| + } |
| + |
| + @AfterClass |
| + public static void afterClass() throws Exception { |
| + reader.close(); |
| + } |
| + |
| + // FunctionQuery equivalent |
| + public void testSimpleSourceScore() throws Exception { |
| + |
| + FunctionScoreQuery q = new FunctionScoreQuery(new TermQuery(new Term(TEXT_FIELD, "first")), |
| + DoubleValuesSource.fromIntField(INT_FIELD)); |
| + |
| + QueryUtils.check(random(), q, searcher, rarely()); |
| + |
| + int expectedDocs[] = new int[]{ 4, 7, 9 }; |
| + TopDocs docs = searcher.search(q, 4); |
| + assertEquals(expectedDocs.length, docs.totalHits); |
| + for (int i = 0; i < expectedDocs.length; i++) { |
| + assertEquals(docs.scoreDocs[i].doc, expectedDocs[i]); |
| + } |
| + |
| + } |
| + |
| + // CustomScoreQuery and BoostedQuery equivalent |
| + public void testScoreModifyingSource() throws Exception { |
| + |
| + DoubleValuesSource iii = DoubleValuesSource.fromIntField("iii"); |
| + DoubleValuesSource score = DoubleValuesSource.scoringFunction(iii, (v, s) -> v * s); |
| + |
| + BooleanQuery bq = new BooleanQuery.Builder() |
| + .add(new TermQuery(new Term(TEXT_FIELD, "first")), BooleanClause.Occur.SHOULD) |
| + .add(new TermQuery(new Term(TEXT_FIELD, "text")), BooleanClause.Occur.SHOULD) |
| + .build(); |
| + TopDocs plain = searcher.search(bq, 1); |
| + |
| + FunctionScoreQuery fq = new FunctionScoreQuery(bq, score); |
| + |
| + QueryUtils.check(random(), fq, searcher, rarely()); |
| + |
| + int[] expectedDocs = new int[]{ 4, 7, 9, 8, 12 }; |
| + TopDocs docs = searcher.search(fq, 5); |
| + assertEquals(plain.totalHits, docs.totalHits); |
| + for (int i = 0; i < expectedDocs.length; i++) { |
| + assertEquals(expectedDocs[i], docs.scoreDocs[i].doc); |
| + |
| + } |
| + |
| + } |
| + |
| + // check boosts with non-distributive score source |
| + public void testBoostsAreAppliedLast() throws Exception { |
| + |
| + DoubleValuesSource scores |
| + = DoubleValuesSource.function(DoubleValuesSource.SCORES, v -> Math.log(v + 4)); |
| + |
| + Query q1 = new FunctionScoreQuery(new TermQuery(new Term(TEXT_FIELD, "text")), scores); |
| + TopDocs plain = searcher.search(q1, 5); |
| + |
| + Query boosted = new BoostQuery(q1, 2); |
| + TopDocs afterboost = searcher.search(boosted, 5); |
| + assertEquals(plain.totalHits, afterboost.totalHits); |
| + for (int i = 0; i < 5; i++) { |
| + assertEquals(plain.scoreDocs[i].doc, afterboost.scoreDocs[i].doc); |
| + assertEquals(plain.scoreDocs[i].score, afterboost.scoreDocs[i].score / 2, 0.0001); |
| + } |
| + |
| + } |
| + |
| +} |