/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.nio.file.Paths;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.solr.SolrTestCase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FieldValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
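/**
 * Tests the LTR re-ranking pipeline end to end: plain rescoring, re-ranking with
 * different topN cut-offs, and propagation of the original document score to the
 * per-feature scorers.
 */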
public class TestLTRReRankingPipeline extends SolrTestCase {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(Paths.get("").toAbsolutePath());
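  // Returns a searcher that may be randomly wrapped, but never with Lucene's
  // asserting wrappers (see the comments on the flags below).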
private IndexSearcher getSearcher(IndexReader r) {
// 'yes' to maybe wrapping in general
final boolean maybeWrap = true;
    // 'no' to asserting wrap because lucene AssertingWeight
    // cannot be cast to solr LTRScoringQuery$ModelWeight
    final boolean wrapWithAssertions = false;
final IndexSearcher searcher = newSearcher(r, maybeWrap, wrapWithAssertions);
return searcher;
}
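  /**
   * Builds one {@link FieldValueFeature} per id over the given field, naming each
   * feature "f" + id and using the id as its feature index.
   */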
private static List<Feature> makeFieldValueFeatures(int[] featureIds,
String field) {
final List<Feature> features = new ArrayList<>();
for (final int i : featureIds) {
final Map<String,Object> params = new HashMap<String,Object>();
params.put("field", field);
final Feature f = Feature.getInstance(solrResourceLoader,
FieldValueFeature.class.getName(),
"f" + i, params);
f.setIndex(i);
features.add(f);
}
return features;
}
private static class MockModel extends LTRScoringModel {
public MockModel(String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
return modelFeatureValuesNormalized[2];
}
@Override
public Explanation explain(LeafReaderContext context, int doc,
float finalScore, List<Explanation> featureExplanations) {
return null;
}
}
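  /**
   * Indexes two documents, runs a plain BooleanQuery, then re-ranks the hits with a
   * linear model over the "final-score" field and verifies that the order is flipped.
   */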
@Test
public void testRescorer() throws IOException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard the the the the the oz",
Field.Store.NO));
doc.add(newStringField("final-score", "F", Field.Store.YES)); // TODO: change to numeric field
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
    // one extra token, but "wizard" and "oz" are adjacent
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(newStringField("final-score", "T", Field.Store.YES)); // TODO: change to numeric field
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(2, hits.totalHits.value);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
"final-score");
final List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
hits = rescorer.rescore(searcher, hits, 2);
// rerank using the field final-score
assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
r.close();
dir.close();
}
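  /**
   * Indexes five documents with increasing "final-score" doc values, verifies that
   * re-ranking with topN=0 leaves the original order untouched, and then checks that
   * re-ranking the first topN hits reverses exactly those documents.
   */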
@AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/SOLR-11134")
@Test
public void testDifferentTopN() throws IOException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 1.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 2.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "2", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 3.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "3", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz the the the the ",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 4.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "4", Field.Store.YES));
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 5.0f));
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(5, hits.totalHits.value);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
"final-score");
final List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
// rerank @ 0 should not change the order
hits = rescorer.rescore(searcher, hits, 0);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
// test rerank with different topN cuts
for (int topN = 1; topN <= 5; topN++) {
log.info("rerank {} documents ", topN);
hits = searcher.search(bqBuilder.build(), 10);
final ScoreDoc[] slice = new ScoreDoc[topN];
System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
hits = new TopDocs(hits.totalHits, slice);
hits = rescorer.rescore(searcher, hits, topN);
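      // After re-ranking, the sliced hits should appear in descending final-score
      // order, i.e. in reverse id order, with each score equal to its final-score.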
for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
if (log.isInfoEnabled()) {
log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
.get("id"), j);
}
assertEquals(i,
Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001);
}
}
r.close();
dir.close();
}
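  /**
   * Verifies that the original document score set on the model scorer's DocInfo is
   * visible to every child feature scorer, both for a single feature and for several.
   */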
@Test
public void testDocParam() throws Exception {
final Map<String,Object> test = new HashMap<String,Object>();
test.put("fake", 2);
List<Feature> features = makeFieldValueFeatures(new int[] {0},
"final-score");
List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
"final-score");
MockModel ltrScoringModel = new MockModel("test",
features, norms, "test", allFeatures, null);
LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
modelScr.getDocInfo().setOriginalDocScore(1f);
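    // The child feature scorers share the model scorer's DocInfo, so the original
    // document score must be visible to each of them.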
for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
}
features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score");
norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
9}, "final-score");
ltrScoringModel = new MockModel("test", features, norms,
"test", allFeatures, null);
query = new LTRScoringQuery(ltrScoringModel);
wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
modelScr = wgt.scorer(null);
modelScr.getDocInfo().setOriginalDocScore(1f);
for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
}
}
}