blob: 7db845fdd0f65f38a13146b135684b2f2c97a6ff [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.SolrIndexSearcher;
/**
 * Implements the rescoring logic. The top documents returned by Solr, together with their
 * original scores, are processed by a {@link LTRScoringQuery} that assigns a new score to
 * each document. The top documents are then re-sorted based on the new score.
 */
public class LTRRescorer extends Rescorer {

  /** The re-ranking model query; null when the no-arg constructor is used. */
  private final LTRScoringQuery scoringQuery;

  public LTRRescorer() {
    this.scoringQuery = null;
  }

  public LTRRescorer(LTRScoringQuery scoringQuery) {
    this.scoringQuery = scoringQuery;
  }

  /** Orders hits by ascending docID so features can be scored segment by segment. */
  private static final Comparator<ScoreDoc> docComparator = Comparator.comparingInt(a -> a.doc);

  /** Orders hits by descending score; ties are broken by ascending docID. */
  protected static final Comparator<ScoreDoc> scoreComparator = (a, b) -> {
    // Sort by score descending, then docID ascending:
    if (a.score > b.score) {
      return -1;
    } else if (a.score < b.score) {
      return 1;
    } else {
      // This subtraction can't overflow int because docIDs are >= 0:
      return a.doc - b.doc;
    }
  };

  /**
   * Sifts the element at {@code root} down the score-keyed min-heap stored in
   * {@code hits[0..size)} until the heap property is restored.
   *
   * @param hits array whose first {@code size} entries form the heap
   * @param size number of heap entries
   * @param root index of the element to sift down
   */
  protected static void heapAdjust(ScoreDoc[] hits, int size, int root) {
    final ScoreDoc doc = hits[root];
    final float score = doc.score;
    int i = root;
    // Continue while node i has at least a left child.
    while (i <= ((size >> 1) - 1)) {
      final int lchild = (i << 1) + 1;
      final ScoreDoc ldoc = hits[lchild];
      final float lscore = ldoc.score;
      // MAX_VALUE sentinel means "no right child" in the comparisons below.
      float rscore = Float.MAX_VALUE;
      final int rchild = (i << 1) + 2;
      ScoreDoc rdoc = null;
      if (rchild < size) {
        rdoc = hits[rchild];
        rscore = rdoc.score;
      }
      if (lscore < score) {
        if (rscore < lscore) {
          // Right child is the smallest: swap down to the right.
          hits[i] = rdoc;
          hits[rchild] = doc;
          i = rchild;
        } else {
          // Left child is the smallest: swap down to the left.
          hits[i] = ldoc;
          hits[lchild] = doc;
          i = lchild;
        }
      } else if (rscore < score) {
        hits[i] = rdoc;
        hits[rchild] = doc;
        i = rchild;
      } else {
        // Neither child is smaller: the heap property holds.
        return;
      }
    }
  }

  /** Builds a score-keyed min-heap over the first {@code size} entries of {@code hits}. */
  protected static void heapify(ScoreDoc[] hits, int size) {
    for (int i = (size >> 1) - 1; i >= 0; i--) {
      heapAdjust(hits, size, i);
    }
  }

  /**
   * rescores the documents:
   *
   * @param searcher
   *          current IndexSearcher
   * @param firstPassTopDocs
   *          documents to rerank;
   * @param topN
   *          documents to return;
   */
  @Override
  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
      int topN) throws IOException {
    if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
      return firstPassTopDocs;
    }
    final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
    // Clamp topN to the number of documents actually collected: totalHits.value may
    // exceed scoreDocs.length, and sizing the reranked array beyond the collected
    // hits would leave null entries that break the final sort.
    topN = Math.toIntExact(Math.min(topN,
        Math.min(firstPassResults.length, firstPassTopDocs.totalHits.value)));
    final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
    return new TopDocs(firstPassTopDocs.totalHits, reranked);
  }

  /**
   * Scores {@code firstPassResults} with the model and returns the top {@code topN}
   * hits sorted by the new score (descending, ties by docID).
   */
  private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
    final ScoreDoc[] reranked = new ScoreDoc[topN];
    final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
    final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
        .createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
    scoreFeatures(searcher, topN, modelWeight, firstPassResults, leaves, reranked);
    // Must sort all documents that we reranked, and then select the top
    Arrays.sort(reranked, scoreComparator);
    return reranked;
  }

  /** @deprecated Use {@link java.util.Arrays#sort(Object[], Comparator)} with {@code scoreComparator} instead. */
  @Deprecated
  protected static void sortByScore(ScoreDoc[] reranked) {
    Arrays.sort(reranked, scoreComparator);
  }

  /** Returns the first-pass hits sorted by ascending docID (required for segment-ordered scoring). */
  protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
    final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
    Arrays.sort(hits, docComparator);
    assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
    return hits;
  }

  /**
   * Walks the docID-sorted {@code hits} segment by segment, scoring each one with the
   * model and maintaining the top {@code topN} results in {@code reranked}.
   */
  public void scoreFeatures(IndexSearcher indexSearcher,
      int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
      ScoreDoc[] reranked) throws IOException {
    int readerUpto = -1;
    int endDoc = 0;
    int docBase = 0;
    LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
    int hitUpto = 0;
    while (hitUpto < hits.length) {
      final ScoreDoc hit = hits[hitUpto];
      final int docID = hit.doc;
      LeafReaderContext readerContext = null;
      // Advance to the segment containing docID (hits are sorted by docID).
      while (docID >= endDoc) {
        readerUpto++;
        readerContext = leaves.get(readerUpto);
        endDoc = readerContext.docBase + readerContext.reader().maxDoc();
      }
      // We advanced to another segment: obtain a fresh scorer for it.
      if (readerContext != null) {
        docBase = readerContext.docBase;
        scorer = modelWeight.scorer(readerContext);
      }
      if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) {
        logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
      }
      hitUpto++;
    }
  }

  /**
   * @deprecated Use {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])}
   * and {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)} instead.
   */
  @Deprecated
  protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
    if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) {
      logSingleHit(indexSearcher, modelWeight, hit.doc, rerankingQuery);
    }
  }

  /**
   * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])}
   * method indicated that the document's feature info should be logged.
   */
  protected static void logSingleHit(IndexSearcher indexSearcher, LTRScoringQuery.ModelWeight modelWeight, int docid, LTRScoringQuery scoringQuery) {
    final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
    // Feature logging requires both a configured logger and a SolrIndexSearcher.
    if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
      featureLogger.log(docid, scoringQuery, (SolrIndexSearcher)indexSearcher, modelWeight.getFeaturesInfo());
    }
  }

  /**
   * Scores a single document and returns true if the document's feature info should be logged via the
   * {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, LTRScoringQuery)}
   * method. Feature info logging is only necessary for the topN documents.
   */
  protected static boolean scoreSingleHit(int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
    // A LTRScoringQuery.ModelWeight scorer should never be null: score() must always be
    // called, even when no feature scorers match, since a model may return a non-zero
    // score from blank features. The same holds when advancing the scorer past the
    // target doc.
    assert (scorer != null);
    final int targetDoc = docID - docBase;
    scorer.iterator().advance(targetDoc);
    boolean logHit = false;
    // Preserve the original score so features may reuse it instead of recomputing.
    scorer.getDocInfo().setOriginalDocScore(hit.score);
    hit.score = scorer.score();
    if (hitUpto < topN) {
      reranked[hitUpto] = hit;
      // if the heap is not full, maybe I want to log the features for this
      // document
      logHit = true;
    } else if (hitUpto == topN) {
      // collected topN document, I create the heap
      heapify(reranked, topN);
    }
    if (hitUpto >= topN) {
      // Once the heap is ready: if this document scores below the current minimum,
      // skip logging; otherwise replace the minimum and restore the heap.
      if (hit.score > reranked[0].score) {
        reranked[0] = hit;
        heapAdjust(reranked, topN, 0);
        logHit = true;
      }
    }
    return logHit;
  }

  @Override
  public Explanation explain(IndexSearcher searcher,
      Explanation firstPassExplanation, int docID) throws IOException {
    return getExplanation(searcher, docID, scoringQuery);
  }

  /**
   * Explains {@code docID}'s score under {@code rerankingQuery}. For an
   * {@link OriginalRankingLTRScoringQuery} the original query's weight is used,
   * since that query re-ranks with the original scores.
   */
  protected static Explanation getExplanation(IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException {
    final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
        .leaves();
    final int n = ReaderUtil.subIndex(docID, leafContexts);
    final LeafReaderContext context = leafContexts.get(n);
    final int deBasedDoc = docID - context.docBase;
    final Weight rankingWeight;
    if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) {
      rankingWeight = rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1);
    } else {
      rankingWeight = searcher.createWeight(searcher.rewrite(rerankingQuery),
          ScoreMode.COMPLETE, 1);
    }
    return rankingWeight.explain(context, deBasedDoc);
  }

  /**
   * Computes and returns the feature values for {@code docid}, or an empty array when
   * the model's scorer does not match the document.
   *
   * @param originalDocScore original (first-pass) score, or null if unavailable
   */
  public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,
      int docid,
      Float originalDocScore,
      List<LeafReaderContext> leafContexts)
          throws IOException {
    final int n = ReaderUtil.subIndex(docid, leafContexts);
    final LeafReaderContext atomicContext = leafContexts.get(n);
    final int deBasedDoc = docid - atomicContext.docBase;
    final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext);
    if ( (r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc) ) {
      return new LTRScoringQuery.FeatureInfo[0];
    } else {
      if (originalDocScore != null) {
        // If results have not been reranked, the score passed in is the original query's
        // score, which some features can use instead of recalculating it
        r.getDocInfo().setOriginalDocScore(originalDocScore);
      }
      r.score();
      return modelWeight.getFeaturesInfo();
    }
  }
}