blob: 96bd795ef3a8ff4e60773458917c2f0bf289f0f3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.response.transform;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreMode;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.CSVFeatureLogger;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.search.LTRQParserPlugin;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.response.transform.DocTransformer;
import org.apache.solr.response.transform.TransformerFactory;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.SolrPluginUtils;
/**
 * This transformer extracts the features declared in the feature store of the current
 * reranking model, or in an explicitly specified feature store, and appends them to each
 * document in the response. Ex. <code>fl=id,[features store=myStore efi.user_text="ibm"]</code>
 *
 * <h3>Parameters</h3>
 * <code>store</code> - The feature store to extract features from. If not provided, the
 * configured default store is used when one is set; otherwise it defaults to the features
 * used by your reranking model.<br>
 * <code>efi.*</code> - External feature information variables required by the features
 * you are extracting.<br>
 * <code>format</code> - The format in which the features are returned. Supports (dense|sparse).
 * Defaults to the configured default format, which is dense unless overridden.<br>
 */
public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
// used inside fl to specify the format (dense|sparse) of the extracted features
private static final String FV_FORMAT = "format";
// used inside fl to specify the feature store to use for the feature extraction
private static final String FV_STORE = "store";
private static String DEFAULT_LOGGING_MODEL_NAME = "logging-model";
private String fvCacheName;
private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME;
private String defaultStore;
private FeatureLogger.FeatureFormat defaultFormat = FeatureLogger.FeatureFormat.DENSE;
private char csvKeyValueDelimiter = CSVFeatureLogger.DEFAULT_KEY_VALUE_SEPARATOR;
private char csvFeatureSeparator = CSVFeatureLogger.DEFAULT_FEATURE_SEPARATOR;
private LTRThreadModule threadManager = null;
public void setFvCacheName(String fvCacheName) {
this.fvCacheName = fvCacheName;
}
public void setLoggingModelName(String loggingModelName) {
this.loggingModelName = loggingModelName;
}
public void setDefaultStore(String defaultStore) {
this.defaultStore = defaultStore;
}
public void setDefaultFormat(String defaultFormat) {
this.defaultFormat = FeatureLogger.FeatureFormat.valueOf(defaultFormat.toUpperCase(Locale.ROOT));
}
public void setCsvKeyValueDelimiter(String csvKeyValueDelimiter) {
if (csvKeyValueDelimiter.length() != 1) {
throw new IllegalArgumentException("csvKeyValueDelimiter must be exactly 1 character");
}
this.csvKeyValueDelimiter = csvKeyValueDelimiter.charAt(0);
}
public void setCsvFeatureSeparator(String csvFeatureSeparator) {
if (csvFeatureSeparator.length() != 1) {
throw new IllegalArgumentException("csvFeatureSeparator must be exactly 1 character");
}
this.csvFeatureSeparator = csvFeatureSeparator.charAt(0);
}
@Override
@SuppressWarnings({"unchecked"})
public void init(@SuppressWarnings("rawtypes") NamedList args) {
super.init(args);
threadManager = LTRThreadModule.getInstance(args);
SolrPluginUtils.invokeSetters(this, args);
}
@Override
public DocTransformer create(String name, SolrParams localparams,
SolrQueryRequest req) {
// Hint to enable feature vector cache since we are requesting features
SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
// Communicate which feature store we are requesting features for
final String fvStoreName = localparams.get(FV_STORE);
SolrQueryRequestContextUtils.setFvStoreName(req, (fvStoreName == null ? defaultStore : fvStoreName));
// Create and supply the feature logger to be used
SolrQueryRequestContextUtils.setFeatureLogger(req,
createFeatureLogger(
localparams.get(FV_FORMAT)));
return new FeatureTransformer(name, localparams, req, (fvStoreName != null) /* hasExplicitFeatureStore */);
}
/**
* returns a FeatureLogger that logs the features
* 'featureFormat' param: 'dense' will write features in dense format,
* 'sparse' will write the features in sparse format, null or empty will
* default to 'sparse'
*
*
* @return a feature logger for the format specified.
*/
private FeatureLogger createFeatureLogger(String formatStr) {
final FeatureLogger.FeatureFormat format;
if (formatStr != null) {
format = FeatureLogger.FeatureFormat.valueOf(formatStr.toUpperCase(Locale.ROOT));
} else {
format = this.defaultFormat;
}
if (fvCacheName == null) {
throw new IllegalArgumentException("a fvCacheName must be configured");
}
return new CSVFeatureLogger(fvCacheName, format, csvKeyValueDelimiter, csvFeatureSeparator);
}
class FeatureTransformer extends DocTransformer {
final private String name;
final private SolrParams localparams;
final private SolrQueryRequest req;
final private boolean hasExplicitFeatureStore;
private List<LeafReaderContext> leafContexts;
private SolrIndexSearcher searcher;
/**
* rerankingQueries, modelWeights have:
* length=1 - [Classic LTR] When reranking with a single model
* length=2 - [Interleaving] When reranking with interleaving (two ranking models are involved)
*/
private LTRScoringQuery[] rerankingQueriesFromContext;
private LTRScoringQuery[] rerankingQueries;
private LTRScoringQuery.ModelWeight[] modelWeights;
private FeatureLogger featureLogger;
private boolean docsWereNotReranked;
/**
* @param name
* Name of the field to be added in a document representing the
* feature vectors
*/
public FeatureTransformer(String name, SolrParams localparams,
SolrQueryRequest req, boolean hasExplicitFeatureStore) {
this.name = name;
this.localparams = localparams;
this.req = req;
this.hasExplicitFeatureStore = hasExplicitFeatureStore;
}
@Override
public String getName() {
return name;
}
@Override
public void setContext(ResultContext context) {
super.setContext(context);
if (context == null) {
return;
}
if (context.getRequest() == null) {
return;
}
searcher = context.getSearcher();
if (searcher == null) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"searcher is null");
}
leafContexts = searcher.getTopReaderContext().leaves();
if (threadManager != null) {
threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
}
rerankingQueriesFromContext = SolrQueryRequestContextUtils.getScoringQueries(req);
docsWereNotReranked = (rerankingQueriesFromContext == null || rerankingQueriesFromContext.length == 0);
String transformerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req);
Map<String, String[]> transformerExternalFeatureInfo = LTRQParserPlugin.extractEFIParams(localparams);
final LoggingModel loggingModel = createLoggingModel(transformerFeatureStore);
setupRerankingQueriesForLogging(transformerFeatureStore, transformerExternalFeatureInfo, loggingModel);
setupRerankingWeightsForLogging(context);
}
/**
* The loggingModel is an empty model that is just used to extract the features
* and log them
* @param transformerFeatureStore the explicit transformer feature store
*/
private LoggingModel createLoggingModel(String transformerFeatureStore) {
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
final FeatureStore store = fr.getFeatureStore(transformerFeatureStore);
transformerFeatureStore = store.getName(); // if transformerFeatureStore was null before this gets actual name
return new LoggingModel(loggingModelName,
transformerFeatureStore, store.getFeatures());
}
/**
* When preparing the reranking queries for logging features various scenarios apply:
*
* No Reranking
* There is the need of a logger model from the default feature store or the explicit feature store passed
* to extract the feature vector
*
* Re Ranking
* 1) If no explicit feature store is passed, the models for each reranking query can be safely re-used
* the feature vector can be fetched from the feature vector cache.
* 2) If an explicit feature store is passed, and no reranking query uses a model with that feature store,
* There is the need of a logger model to extract the feature vector
* 3) If an explicit feature store is passed, and there is a reranking query that uses a model with that feature store,
* the model can be re-used and there is no need for a logging model
*
* @param transformerFeatureStore explicit feature store for the transformer
* @param transformerExternalFeatureInfo explicit efi for the transformer
*/
private void setupRerankingQueriesForLogging(String transformerFeatureStore, Map<String, String[]> transformerExternalFeatureInfo, LoggingModel loggingModel) {
if (docsWereNotReranked) { //no reranking query
LTRScoringQuery loggingQuery = new LTRScoringQuery(loggingModel,
transformerExternalFeatureInfo,
true /* extractAllFeatures */,
threadManager);
rerankingQueries = new LTRScoringQuery[]{loggingQuery};
} else {
rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length];
System.arraycopy(rerankingQueriesFromContext, 0, rerankingQueries, 0, rerankingQueriesFromContext.length);
if (transformerFeatureStore != null) {// explicit feature store for the transformer
LTRScoringModel matchingRerankingModel = loggingModel;
for (LTRScoringQuery rerankingQuery : rerankingQueries) {
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) &&
transformerFeatureStore.equals(rerankingQuery.getScoringModel().getFeatureStoreName())) {
matchingRerankingModel = rerankingQuery.getScoringModel();
}
}
for (int i = 0; i < rerankingQueries.length; i++) {
rerankingQueries[i] = new LTRScoringQuery(
matchingRerankingModel,
(!transformerExternalFeatureInfo.isEmpty() ? transformerExternalFeatureInfo : rerankingQueries[i].getExternalFeatureInfo()),
true /* extractAllFeatures */,
threadManager);
}
}
}
}
private void setupRerankingWeightsForLogging(ResultContext context) {
modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
for (int i = 0; i < rerankingQueries.length; i++) {
if (rerankingQueries[i].getOriginalQuery() == null) {
rerankingQueries[i].setOriginalQuery(context.getQuery());
}
rerankingQueries[i].setRequest(req);
if (!(rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
if (rerankingQueries[i].getFeatureLogger() == null) {
rerankingQueries[i].setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req));
}
featureLogger = rerankingQueries[i].getFeatureLogger();
try {
modelWeights[i] = rerankingQueries[i].createWeight(searcher, ScoreMode.COMPLETE, 1f);
} catch (final IOException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
}
if (modelWeights[i] == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"error logging the features, model weight is null");
}
}
}
}
@Override
public void transform(SolrDocument doc, int docid, float score)
throws IOException {
implTransform(doc, docid, score);
}
@Override
public void transform(SolrDocument doc, int docid)
throws IOException {
implTransform(doc, docid, null);
}
private void implTransform(SolrDocument doc, int docid, Float score)
throws IOException {
LTRScoringQuery rerankingQuery = rerankingQueries[0];
LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
for (int i = 1; i < rerankingQueries.length; i++) {
if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
rerankingQuery = rerankingQueries[i];
rerankingModelWeight = modelWeights[i];
}
}
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
if (featureVector == null) { // FV for this document was not in the cache
featureVector = featureLogger.makeFeatureVector(
LTRRescorer.extractFeaturesInfo(
rerankingModelWeight,
docid,
(docsWereNotReranked ? score : null),
leafContexts));
}
doc.addField(name, featureVector);
}
}
}
private static class LoggingModel extends LTRScoringModel {
public LoggingModel(String name, String featureStoreName, List<Feature> allFeatures){
this(name, Collections.emptyList(), Collections.emptyList(),
featureStoreName, allFeatures, Collections.emptyMap());
}
protected LoggingModel(String name, List<Feature> features,
List<Normalizer> norms, String featureStoreName,
List<Feature> allFeatures, Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
return 0;
}
@Override
public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
return Explanation.match(finalScore, toString()
+ " logging model, used only for logging the features");
}
}
}