blob: bf4c19b40634d7144ef42107e19d4a5e3d4800b4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import com.carrotsearch.hppc.IntFloatHashMap;
import com.carrotsearch.hppc.IntIntHashMap;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.handler.component.QueryElevationComponent;
import org.apache.solr.request.SolrRequestInfo;
/**
 * A {@link TopDocsCollector} used by re-ranking queries.
 *
 * <p>Collection is delegated to a main collector that keeps the top {@code max(reRankDocs,
 * length)} first-pass hits, ordered by score or by the command's {@link Sort} when one is given.
 * {@link #topDocs(int, int)} then rescores the first {@code reRankDocs} of those hits. It
 * optionally scales the scores with a {@link ReRankScaler} and re-sorts so that elevated documents
 * (resolved via {@link QueryElevationComponent#getBoostDocs}) come first.
 */
public class ReRankCollector extends TopDocsCollector<ScoreDoc> {

  /** Delegate that performs the actual first-pass collection. */
  private final TopDocsCollector<? extends ScoreDoc> mainCollector;

  private final IndexSearcher searcher;

  /** Number of top first-pass hits handed to the rescorer. */
  private final int reRankDocs;

  /** Requested result length; the main collector keeps at least this many hits. */
  private final int length;

  private final Set<BytesRef> boostedPriority; // order is the "priority"

  /** Rescorer used when scores are not being scaled. */
  private final Rescorer reRankQueryRescorer;

  /** Rewritten request sort, or {@code null} when first-pass hits are ordered by score. */
  private final Sort sort;

  /** Main query; needed to populate scores for field-sorted hits. */
  private final Query query;

  /** Optional score scaler; {@code null} when none was supplied. */
  private final ReRankScaler reRankScaler;

  /** Supplied at construction but not currently read anywhere in this class. */
  private final ReRankOperator reRankOperator;

  /**
   * Creates a re-rank collector with optional score scaling.
   *
   * @param reRankDocs number of top first-pass hits to rescore
   * @param length number of hits requested; the first pass keeps {@code max(reRankDocs, length)}
   * @param reRankQueryRescorer rescorer applied when scores are not being scaled
   * @param cmd supplies the query, the optional sort and the minimum exact hit count
   * @param searcher searcher used for sort rewriting, score population and rescoring
   * @param boostedPriority set handed to {@link QueryElevationComponent#getBoostDocs} to locate
   *     elevated documents, or {@code null} when nothing is elevated
   * @param reRankScaler optional score scaler, may be {@code null}
   * @param reRankOperator stored only; not used by this class, may be {@code null}
   * @throws IOException propagated from rewriting the sort
   */
  public ReRankCollector(
      int reRankDocs,
      int length,
      Rescorer reRankQueryRescorer,
      QueryCommand cmd,
      IndexSearcher searcher,
      Set<BytesRef> boostedPriority,
      ReRankScaler reRankScaler,
      ReRankOperator reRankOperator)
      throws IOException {
    super(null); // topDocs/getTotalHits are overridden and served by mainCollector
    this.reRankDocs = reRankDocs;
    this.length = length;
    this.boostedPriority = boostedPriority;
    this.query = cmd.getQuery();
    this.searcher = searcher;
    this.reRankQueryRescorer = reRankQueryRescorer;
    this.reRankScaler = reRankScaler;
    this.reRankOperator = reRankOperator;

    int numHitsToCollect = Math.max(reRankDocs, length);
    Sort requestedSort = cmd.getSort();
    if (requestedSort == null) {
      this.sort = null;
      this.mainCollector =
          TopScoreDocCollector.create(numHitsToCollect, cmd.getMinExactCount());
    } else {
      this.sort = requestedSort.rewrite(searcher);
      // The Rescorer needs scores regardless of whether the sort does; they are filled in later
      // by TopFieldCollector.populateScores in topDocs().
      this.mainCollector =
          TopFieldCollector.create(this.sort, numHitsToCollect, cmd.getMinExactCount());
    }
  }

  /**
   * Creates a re-rank collector without score scaling (scaler and operator are {@code null}).
   *
   * @throws IOException propagated from rewriting the sort
   */
  public ReRankCollector(
      int reRankDocs,
      int length,
      Rescorer reRankQueryRescorer,
      QueryCommand cmd,
      IndexSearcher searcher,
      Set<BytesRef> boostedPriority)
      throws IOException {
    this(reRankDocs, length, reRankQueryRescorer, cmd, searcher, boostedPriority, null, null);
  }

  /** Returns the total hit count reported by the main (first-pass) collector. */
  @Override
  public int getTotalHits() {
    return mainCollector.getTotalHits();
  }

  /** Delegates per-segment collection to the main collector. */
  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    return mainCollector.getLeafCollector(context);
  }

  /** Uses whatever score mode the main collector requires. */
  @Override
  public ScoreMode scoreMode() {
    return mainCollector.scoreMode();
  }

  /**
   * Returns the re-ranked hits.
   *
   * <p>The top {@code max(reRankDocs, length)} first-pass hits are fetched and the first {@code
   * reRankDocs} of them are rescored. When elevation is configured, the rescored hits are re-sorted
   * so that elevated documents come first. The result is then cut down to {@code howMany}, or
   * padded out to it with the remaining first-pass hits.
   *
   * <p>NOTE: {@code start} is not used; results always begin at the first hit.
   *
   * @param start ignored
   * @param howMany maximum number of hits to return; lowered to the number of hits collected
   * @return the re-ranked hits
   * @throws SolrException with code BAD_REQUEST wrapping any exception raised while re-ranking
   */
  @Override
  public TopDocs topDocs(int start, int howMany) {
    try {
      TopDocs mainDocs = mainCollector.topDocs(0, Math.max(reRankDocs, length));
      if (mainDocs.totalHits.value == 0 || mainDocs.scoreDocs.length == 0) {
        return mainDocs;
      }

      if (sort != null) {
        // Field-sorted hits are given scores here because the rescorer needs them.
        TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query);
      }

      ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
      // Evaluated once: this decides the zeroing, the rescorer choice and the final scaling.
      boolean scaleScores = reRankScaler != null && reRankScaler.scaleScores();
      // Snapshot the first-pass hits with their original scores. When scaling, the originals'
      // scores are zeroed in place so the replace rescorer yields just the re-rank score.
      ScoreDoc[] mainScoreDocsClone = deepClone(mainScoreDocs, scaleScores);

      // Only the top reRankDocs hits are handed to the rescorer.
      ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)];
      System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length);
      mainDocs.scoreDocs = reRankScoreDocs;

      TopDocs rescoredDocs;
      try {
        rescoredDocs =
            scaleScores // previously zero-ed out scores are to be replaced
                ? reRankScaler
                    .getReplaceRescorer()
                    .rescore(searcher, mainDocs, mainDocs.scoreDocs.length)
                : reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
      } catch (IncompleteRerankingException ex) {
        // Fall back to all first-pass hits with their original (un-zeroed) scores.
        mainDocs.scoreDocs = mainScoreDocsClone;
        rescoredDocs = mainDocs;
      }

      // Lower the number to return if we've collected fewer documents.
      int numToReturn = Math.min(howMany, mainScoreDocs.length);

      if (boostedPriority != null) {
        SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
        Map<Object, Object> requestContext = info != null ? info.getReq().getContext() : null;
        IntIntHashMap boostedDocs =
            QueryElevationComponent.getBoostDocs(
                (SolrIndexSearcher) searcher, boostedPriority, requestContext);
        // Elevated docs are scored relative to the first rescored hit, which is presumably the
        // top score after rescoring.
        float maxScore =
            rescoredDocs.scoreDocs.length == 0 ? Float.NaN : rescoredDocs.scoreDocs[0].score;
        Arrays.sort(
            rescoredDocs.scoreDocs, new BoostedComp(boostedDocs, mainDocs.scoreDocs, maxScore));
      }

      int numRescored = rescoredDocs.scoreDocs.length;
      if (numToReturn == numRescored) {
        if (scaleScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(
                  mainScoreDocsClone, rescoredDocs.scoreDocs, reRankScoreDocs.length);
        }
        return rescoredDocs; // Just return the rescoredDocs
      } else if (numToReturn > numRescored) {
        // We need to return more than we've re-ranked: lay down the first-pass hits, then
        // overlay the re-ranked hits at the front.
        ScoreDoc[] scoreDocs = new ScoreDoc[numToReturn];
        System.arraycopy(mainScoreDocs, 0, scoreDocs, 0, numToReturn);
        System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, numRescored);
        rescoredDocs.scoreDocs = scoreDocs;
        if (scaleScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(
                  mainScoreDocsClone, rescoredDocs.scoreDocs, reRankScoreDocs.length);
        }
        return rescoredDocs;
      } else {
        // We've re-ranked more than we need to return.
        if (scaleScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(mainScoreDocsClone, rescoredDocs.scoreDocs, numRescored);
        }
        ScoreDoc[] scoreDocs = new ScoreDoc[numToReturn];
        System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, numToReturn);
        rescoredDocs.scoreDocs = scoreDocs;
        return rescoredDocs;
      }
    } catch (Exception e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }

  /**
   * Returns a copy of {@code scoreDocs} in which every non-null entry is a new {@link ScoreDoc}
   * with the same doc id and score. {@code null} entries stay {@code null}.
   *
   * <p>Side effect: when {@code zeroOut} is true, the score of each ORIGINAL entry is set to 0
   * after it has been copied, so the returned copies keep the original scores.
   */
  private ScoreDoc[] deepClone(ScoreDoc[] scoreDocs, boolean zeroOut) {
    ScoreDoc[] copies = new ScoreDoc[scoreDocs.length];
    for (int i = 0; i < scoreDocs.length; i++) {
      ScoreDoc original = scoreDocs[i];
      if (original == null) {
        continue;
      }
      copies[i] = new ScoreDoc(original.doc, original.score);
      if (zeroOut) {
        original.score = 0f;
      }
    }
    return copies;
  }

  /**
   * Orders hits by descending score. Elevated hits get the synthetic score {@code maxScore +
   * priority}, where priority is the value stored for the doc in the boosted-docs map.
   *
   * <p>Only the leading run of elevated docs in the supplied {@code scoreDocs} is recorded. The
   * scan stops at the first non-elevated doc, which assumes elevated docs come first in that array.
   */
  public static class BoostedComp implements Comparator<ScoreDoc> {
    final IntFloatHashMap boostedMap;

    /**
     * @param boostedDocs elevated doc id to priority value
     * @param scoreDocs hits scanned, in order, for the leading run of elevated docs
     * @param maxScore base score added to each elevated doc's priority
     */
    public BoostedComp(IntIntHashMap boostedDocs, ScoreDoc[] scoreDocs, float maxScore) {
      this.boostedMap = new IntFloatHashMap(boostedDocs.size() * 2);
      for (ScoreDoc scoreDoc : scoreDocs) {
        int idx = boostedDocs.indexOf(scoreDoc.doc);
        if (idx < 0) {
          break; // end of the leading run of elevated docs
        }
        boostedMap.put(scoreDoc.doc, maxScore + boostedDocs.indexGet(idx));
      }
    }

    @Override
    public int compare(ScoreDoc doc1, ScoreDoc doc2) {
      // Reversed arguments give descending order.
      return Float.compare(effectiveScore(doc2), effectiveScore(doc1));
    }

    /** Returns the synthetic score for an elevated doc, otherwise the hit's own score. */
    private float effectiveScore(ScoreDoc scoreDoc) {
      int idx = boostedMap.indexOf(scoreDoc.doc);
      return idx >= 0 ? boostedMap.indexGet(idx) : scoreDoc.score;
    }
  }
}