/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A scoring model computes scores that can be used to rerank documents.
 * <p>
 * A scoring model consists of
 * <ul>
 * <li> a list of features ({@link Feature}) and
 * <li> a list of normalizers ({@link Normalizer}) plus
 * <li> parameters or configuration to represent the scoring algorithm.
 * </ul>
 * <p>
 * Example configuration (snippet):
 * <pre>{
   "class" : "...",
   "name" : "myModelName",
   "features" : [
       {
         "name" : "isBook"
       },
       {
         "name" : "originalScore",
         "norm": {
             "class" : "org.apache.solr.ltr.norm.StandardNormalizer",
             "params" : { "avg":"100", "std":"10" }
         }
       },
       {
         "name" : "price",
         "norm": {
             "class" : "org.apache.solr.ltr.norm.MinMaxNormalizer",
             "params" : { "min":"0", "max":"1000" }
         }
       }
   ],
   "params" : {
       ...
   }
}</pre>
 * <p>
 * {@link LTRScoringModel} is an abstract class and concrete classes must
 * implement the {@link #score(float[])} and
 * {@link #explain(LeafReaderContext, int, float, List)} methods.
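 * <p>
 * For illustration only, a minimal concrete subclass might look like the
 * sketch below (the class name {@code SumModel} is an assumed example, not
 * part of this API):
 * <pre>{@code
 * public class SumModel extends LTRScoringModel {
 *   public SumModel(String name, List<Feature> features, List<Normalizer> norms,
 *       String featureStoreName, List<Feature> allFeatures, Map<String,Object> params) {
 *     super(name, features, norms, featureStoreName, allFeatures, params);
 *   }
 *
 *   public float score(float[] modelFeatureValuesNormalized) {
 *     float total = 0f;
 *     for (float value : modelFeatureValuesNormalized) {
 *       total += value; // unweighted sum of the normalized feature values
 *     }
 *     return total;
 *   }
 *
 *   public Explanation explain(LeafReaderContext context, int doc,
 *       float finalScore, List<Explanation> featureExplanations) {
 *     return Explanation.match(finalScore, "sum of feature values:", featureExplanations);
 *   }
 * }
 * }</pre>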
 */
public abstract class LTRScoringModel implements Accountable {

  private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(LTRScoringModel.class);

  protected final String name;
  private final String featureStoreName;
  protected final List<Feature> features;
  private final List<Feature> allFeatures;
  private final Map<String,Object> params;
  protected final List<Normalizer> norms;
  private Integer hashCode; // cached since it shouldn't actually change after construction
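
  /**
   * Creates and validates a model instance: {@code className} is loaded via the
   * given {@link SolrResourceLoader}, {@code params} entries are applied to any
   * matching setters through {@link SolrPluginUtils#invokeSetters}, and
   * {@link #validate()} is invoked on the result.
   *
   * @throws ModelException if the model cannot be instantiated or is invalid
   */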
  @SuppressWarnings({"rawtypes"})
  public static LTRScoringModel getInstance(SolrResourceLoader solrResourceLoader,
      String className, String name, List<Feature> features,
      List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures,
      Map<String,Object> params) throws ModelException {
    final LTRScoringModel model;
    try {
      // create an instance of the model
      model = solrResourceLoader.newInstance(
          className,
          LTRScoringModel.class,
          new String[0], // no sub packages
          new Class[] { String.class, List.class, List.class, String.class, List.class, Map.class },
          new Object[] { name, features, norms, featureStoreName, allFeatures, params });
      if (params != null) {
        SolrPluginUtils.invokeSetters(model, params.entrySet());
      }
    } catch (final Exception e) {
      throw new ModelException("Model loading failed for " + className, e);
    }
    model.validate();
    return model;
  }

  public LTRScoringModel(String name, List<Feature> features,
      List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures,
      Map<String,Object> params) {
    this.name = name;
    this.features = features != null ? Collections.unmodifiableList(new ArrayList<>(features)) : null;
    this.featureStoreName = featureStoreName;
    this.allFeatures = allFeatures != null ? Collections.unmodifiableList(new ArrayList<>(allFeatures)) : null;
    this.params = params != null ? Collections.unmodifiableMap(new LinkedHashMap<>(params)) : null;
    this.norms = norms != null ? Collections.unmodifiableList(new ArrayList<>(norms)) : null;
  }

  /**
   * Validates that the model's settings make sense and throws
   * {@link ModelException} if they do not.
   */
  protected void validate() throws ModelException {
    final List<Feature> features = getFeatures();
    final List<Normalizer> norms = getNorms();
    if (features.isEmpty()) {
      throw new ModelException("no features declared for model "+name);
    }
    final HashSet<String> featureNames = new HashSet<>();
    for (final Feature feature : features) {
      final String featureName = feature.getName();
      if (!featureNames.add(featureName)) {
        throw new ModelException("duplicated feature "+featureName+" in model "+name);
      }
    }
    if (features.size() != norms.size()) {
      throw new ModelException("counted "+features.size()+" features and "+norms.size()+" norms in model "+name);
    }
  }

  /**
   * @return the norms
   */
  public List<Normalizer> getNorms() {
    return norms;
  }

  /**
   * @return the name
   */
  public String getName() {
    return name;
  }

  /**
   * @return the features
   */
  public List<Feature> getFeatures() {
    return features;
  }

  public Map<String,Object> getParams() {
    return params;
  }

  @Override
  public int hashCode() {
    if (hashCode == null) {
      hashCode = calculateHashCode();
    }
    return hashCode;
  }

  private final int calculateHashCode() {
    final int prime = 31;
    int result = 1;
    result = (prime * result) + Objects.hashCode(features);
    result = (prime * result) + Objects.hashCode(name);
    result = (prime * result) + Objects.hashCode(params);
    result = (prime * result) + Objects.hashCode(norms);
    result = (prime * result) + Objects.hashCode(featureStoreName);
    return result;
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null) {
      return false;
    }
    if (getClass() != obj.getClass()) {
      return false;
    }
    final LTRScoringModel other = (LTRScoringModel) obj;
    if (features == null) {
      if (other.features != null) {
        return false;
      }
    } else if (!features.equals(other.features)) {
      return false;
    }
    if (norms == null) {
      if (other.norms != null) {
        return false;
      }
    } else if (!norms.equals(other.norms)) {
      return false;
    }
    if (name == null) {
      if (other.name != null) {
        return false;
      }
    } else if (!name.equals(other.name)) {
      return false;
    }
    if (params == null) {
      if (other.params != null) {
        return false;
      }
    } else if (!params.equals(other.params)) {
      return false;
    }
    if (featureStoreName == null) {
      if (other.featureStoreName != null) {
        return false;
      }
    } else if (!featureStoreName.equals(other.featureStoreName)) {
      return false;
    }
    return true;
  }

  @Override
  public long ramBytesUsed() {
    return BASE_RAM_BYTES +
        RamUsageEstimator.sizeOfObject(allFeatures, RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED) +
        RamUsageEstimator.sizeOfObject(features, RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED) +
        RamUsageEstimator.sizeOfObject(featureStoreName) +
        RamUsageEstimator.sizeOfObject(name) +
        RamUsageEstimator.sizeOfObject(norms) +
        RamUsageEstimator.sizeOfObject(params);
  }

  public Collection<Feature> getAllFeatures() {
    return allFeatures;
  }

  public String getFeatureStoreName() {
    return featureStoreName;
  }

  /**
   * Given a list of normalized values for all features a scoring algorithm
   * cares about, calculate and return a score.
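   * <p>
   * For illustration only, a hypothetical linear subclass holding an assumed
   * {@code weights} array (one weight per feature) might implement this as a
   * weighted sum:
   * <pre>{@code
   * public float score(float[] modelFeatureValuesNormalized) {
   *   float total = 0f;
   *   for (int i = 0; i < modelFeatureValuesNormalized.length; i++) {
   *     total += weights[i] * modelFeatureValuesNormalized[i];
   *   }
   *   return total;
   * }
   * }</pre>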
   *
   * @param modelFeatureValuesNormalized
   *          List of normalized feature values. Each feature is identified by
   *          its id, which is the index in the array
   * @return The final score for a document
   */
  public abstract float score(float[] modelFeatureValuesNormalized);

  /**
   * Similar to the score() function, except it returns an explanation of how
   * the features were used to calculate the score.
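   * <p>
   * As an illustrative sketch (not the required implementation), a subclass
   * could simply wrap the per-feature explanations:
   * <pre>{@code
   * return Explanation.match(finalScore, getName() + " sum of:", featureExplanations);
   * }</pre>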
   *
   * @param context
   *          Context the document is in
   * @param doc
   *          Document to explain
   * @param finalScore
   *          Original score
   * @param featureExplanations
   *          Explanations for each feature calculation
   * @return Explanation for the scoring of a document
   */
  public abstract Explanation explain(LeafReaderContext context, int doc,
      float finalScore, List<Explanation> featureExplanations);

  @Override
  public String toString() {
    return getClass().getSimpleName() + "(name="+getName()+")";
  }

  /**
   * Goes through all the stored feature values, and calculates the normalized
   * values for all the features that will be used for scoring.
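   * <p>
   * Typical usage (a sketch; {@code values} is assumed to already hold the raw
   * feature values, indexed by feature id):
   * <pre>{@code
   * model.normalizeFeaturesInPlace(values); // values are overwritten in place
   * float score = model.score(values);
   * }</pre>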
   */
  public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
    float[] modelFeatureValuesNormalized = modelFeatureValues;
    if (modelFeatureValues.length != norms.size()) {
      throw new FeatureException("Must have normalizer for every feature");
    }
    for (int idx = 0; idx < modelFeatureValuesNormalized.length; ++idx) {
      modelFeatureValuesNormalized[idx] =
          norms.get(idx).normalize(modelFeatureValuesNormalized[idx]);
    }
  }

  public Explanation getNormalizerExplanation(Explanation e, int idx) {
    Normalizer n = norms.get(idx);
    if (n != IdentityNormalizer.INSTANCE) {
      return n.explain(e);
    }
    return e;
  }

}