blob: d79db8382d71324f1ddc404b1086e55354b7232b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
/**
* A scoring model that computes scores based on the summation of multiple weighted trees.
* Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
* <p>
* Example configuration:
<pre>{
"class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name" : "multipleadditivetreesmodel",
"features":[
{ "name" : "userTextTitleMatch"},
{ "name" : "originalScore"}
],
"params" : {
"trees" : [
{
"weight" : "1",
"root": {
"feature" : "userTextTitleMatch",
"threshold" : "0.5",
"left" : {
"value" : "-100"
},
"right" : {
"feature" : "originalScore",
"threshold" : "10.0",
"left" : {
"value" : "50"
},
"right" : {
"value" : "75"
}
}
}
},
{
"weight" : "2",
"root" : {
"value" : "-10"
}
}
]
}
}</pre>
* <p>
* Training libraries:
* <ul>
* <li> <a href="http://sourceforge.net/p/lemur/wiki/RankLib/">RankLib</a>
* </ul>
* <p>
* Background reading:
* <ul>
* <li> <a href="http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf">
* Christopher J.C. Burges. From RankNet to LambdaRank to LambdaMART: An Overview.
* Microsoft Research Technical Report MSR-TR-2010-82.</a>
* </ul>
* <ul>
* <li> <a href="https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf">
* Z. Zheng, H. Zha, T. Zhang, O. Chapelle, K. Chen, and G. Sun. A General Boosting Method and its Application to Learning Ranking Functions for Web Search.
* Advances in Neural Information Processing Systems (NIPS), 2007.</a>
* </ul>
*/
public class MultipleAdditiveTreesModel extends LTRScoringModel {
/**
* Maps each model feature name to its position in the {@code features}
* list passed to the constructor. Tree nodes look a split's feature name
* up here to obtain the index used to read the feature vector.
* <p>
* fname2index is filled from constructor arguments
* (that are already part of the base class hashCode)
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private final HashMap<String,Integer> fname2index;
/**
* The regression trees whose scores are summed by this model. Assigned by
* {@link #setTrees(Object)}; remains {@code null} until that setter runs.
* <p>
* trees is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private List<RegressionTree> trees;
/**
 * Creates a new {@link RegressionTree} and, if a configuration map is supplied,
 * hands its entries to {@link SolrPluginUtils#invokeSetters} to populate the tree.
 *
 * @param map tree configuration; when {@code null} the tree is returned unpopulated
 * @return the newly created tree, never {@code null}
 */
private RegressionTree createRegressionTree(Map<String,Object> map) {
  final RegressionTree tree = new RegressionTree();
  if (map == null) {
    return tree;
  }
  SolrPluginUtils.invokeSetters(tree, map.entrySet());
  return tree;
}
/**
 * Creates a new {@link RegressionTreeNode} and, if a configuration map is supplied,
 * hands its entries to {@link SolrPluginUtils#invokeSetters} to populate the node.
 *
 * @param map node configuration; when {@code null} the node is returned unpopulated
 * @return the newly created node, never {@code null}
 */
private RegressionTreeNode createRegressionTreeNode(Map<String,Object> map) {
  final RegressionTreeNode node = new RegressionTreeNode();
  if (map == null) {
    return node;
  }
  SolrPluginUtils.invokeSetters(node, map.entrySet());
  return node;
}
/**
 * A node of a regression tree. A node is either a leaf, which carries a
 * {@code value}, or a split on a named feature with a threshold and
 * {@code left}/{@code right} children. A node is a leaf exactly when no
 * feature has been set on it.
 */
public class RegressionTreeNode {
  /**
   * Amount added to every configured threshold and subtracted again when the
   * threshold is printed by {@link #toString()}.
   * NOTE(review): presumably this absorbs floating-point rounding for feature
   * values that sit on the split point - confirm.
   */
  private static final float NODE_SPLIT_SLACK = 1E-6f;

  /** Leaf value; only returned by scoring when this node is a leaf. */
  private float value = 0f;
  /** Name of the feature this node splits on; {@code null} marks a leaf. */
  private String feature;
  /** Index of {@code feature} in the model's feature list, or -1 if the name is unknown. */
  private int featureIndex = -1;
  /** Configured threshold plus {@code NODE_SPLIT_SLACK}; {@code null} until a threshold is set. */
  private Float threshold;
  private RegressionTreeNode left;
  private RegressionTreeNode right;

  public RegressionTreeNode() {
  }

  /**
   * Sets the leaf value.
   *
   * @param value value returned when scoring reaches this leaf
   */
  public void setValue(float value) {
    this.value = value;
  }

  /**
   * Sets the leaf value from its textual form.
   *
   * @param value text parsed with {@link Float#parseFloat(String)}
   * @throws NumberFormatException if the text is not a parsable float
   */
  public void setValue(String value) {
    setValue(Float.parseFloat(value));
  }

  /**
   * Sets the feature this node splits on and resolves its index in the
   * enclosing model's feature list.
   *
   * @param feature feature name as it appears in the model configuration
   */
  public void setFeature(String feature) {
    this.feature = feature;
    final Integer idx = fname2index.get(this.feature);
    // this happens if the tree specifies a feature that does not exist
    // this could be due to lambdaSmart building off of pre-existing trees
    // that use a feature that is no longer output during feature extraction
    featureIndex = (idx == null) ? -1 : idx;
  }

  /**
   * Sets the split threshold; the stored value is the argument plus
   * {@code NODE_SPLIT_SLACK}.
   *
   * @param threshold configured split threshold
   */
  public void setThreshold(float threshold) {
    this.threshold = threshold + NODE_SPLIT_SLACK;
  }

  /**
   * Sets the split threshold from its textual form.
   *
   * @param threshold text parsed with {@link Float#parseFloat(String)}
   * @throws NumberFormatException if the text is not a parsable float
   */
  public void setThreshold(String threshold) {
    setThreshold(Float.parseFloat(threshold));
  }

  /**
   * Builds the left child from its configuration.
   *
   * @param left child configuration, cast to {@code Map<String,Object>}
   */
  @SuppressWarnings({"unchecked"})
  public void setLeft(Object left) {
    this.left = createRegressionTreeNode((Map<String,Object>) left);
  }

  /**
   * Builds the right child from its configuration.
   *
   * @param right child configuration, cast to {@code Map<String,Object>}
   */
  @SuppressWarnings({"unchecked"})
  public void setRight(Object right) {
    this.right = createRegressionTreeNode((Map<String,Object>) right);
  }

  /**
   * @return {@code true} when no feature has been set on this node
   */
  public boolean isLeaf() {
    return feature == null;
  }

  /**
   * Renders a leaf as its value and a split node as
   * {@code (feature=...,threshold=...,left=...,right=...)}, printing the
   * threshold without the slack that the setters added.
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder();
    if (isLeaf()) {
      sb.append(value);
    } else {
      sb.append("(feature=").append(feature);
      sb.append(",threshold=");
      if (threshold == null) {
        // A split node may lack a threshold before validation has rejected it,
        // and validation itself prints nodes in its error messages. Unboxing
        // here would replace the intended ModelException with an NPE.
        sb.append("null");
      } else {
        sb.append(threshold.floatValue() - NODE_SPLIT_SLACK);
      }
      sb.append(",left=").append(left);
      sb.append(",right=").append(right);
      sb.append(')');
    }
    return sb.toString();
  }
}
/**
 * One tree of the ensemble: a root node plus the weight its score is
 * multiplied by.
 */
public class RegressionTree {
  private Float weight;
  private RegressionTreeNode root;

  public RegressionTree() {
  }

  /**
   * @param weight multiplier applied to this tree's node score
   */
  public void setWeight(float weight) {
    this.weight = weight;
  }

  /**
   * @param weight textual weight, parsed as a float
   * @throws NumberFormatException if the text is not a parsable float
   */
  public void setWeight(String weight) {
    setWeight(Float.parseFloat(weight));
  }

  /**
   * Builds the root node from its configuration.
   *
   * @param root root configuration, cast to {@code Map<String,Object>}
   */
  @SuppressWarnings({"unchecked"})
  public void setRoot(Object root) {
    final Map<String,Object> rootConfig = (Map<String,Object>) root;
    this.root = createRegressionTreeNode(rootConfig);
  }

  /**
   * @param featureVector feature values indexed by model feature position
   * @return the tree's weight multiplied by the value reached from the root
   */
  public float score(float[] featureVector) {
    final float multiplier = weight.floatValue();
    return multiplier * scoreNode(featureVector, root);
  }

  /**
   * @param featureVector feature values indexed by model feature position
   * @return a textual trace of the path taken from the root
   */
  public String explain(float[] featureVector) {
    return explainNode(featureVector, root);
  }

  /**
   * Checks that both a weight and a root were configured, then validates the
   * nodes reachable from the root.
   *
   * @throws ModelException if the weight or the root is missing, or a node is malformed
   */
  public void validate() throws ModelException {
    if (weight == null) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
    }
    if (root == null) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
    }
    validateNode(root);
  }

  @Override
  public String toString() {
    return "(weight=" + weight + ",root=" + root + ")";
  }
}
/**
 * Replaces this model's trees with ones built from the given configuration.
 *
 * @param trees tree configurations; cast to a {@code List} whose elements are
 *              each cast to {@code Map<String,Object>}
 */
@SuppressWarnings({"unchecked"})
public void setTrees(Object trees) {
  final List<RegressionTree> built = new ArrayList<RegressionTree>();
  // Publish the list before filling it, exactly as the field was assigned up front before.
  this.trees = built;
  final List<Object> configs = (List<Object>) trees;
  for (final Object config : configs) {
    built.add(createRegressionTree((Map<String,Object>) config));
  }
}
/**
 * Creates the model and records, for every entry of {@code features}, the
 * position of its name in that list. If a name occurs more than once the
 * last position wins.
 *
 * @param name             passed to the base class
 * @param features         model features; their order defines the feature indices
 * @param norms            passed to the base class
 * @param featureStoreName passed to the base class
 * @param allFeatures      passed to the base class
 * @param params           passed to the base class
 */
public MultipleAdditiveTreesModel(String name, List<Feature> features,
    List<Normalizer> norms,
    String featureStoreName, List<Feature> allFeatures,
    Map<String,Object> params) {
  super(name, features, norms, featureStoreName, allFeatures, params);
  fname2index = new HashMap<>();
  int position = 0;
  for (final Feature feature : features) {
    fname2index.put(feature.getName(), position);
    position++;
  }
}
/**
 * Runs the base-class validation, then requires that trees were declared and
 * validates each of them in list order.
 *
 * @throws ModelException if no trees were set or any tree is invalid
 */
@Override
protected void validate() throws ModelException {
  super.validate();
  if (trees == null) {
    throw new ModelException("no trees declared for model "+name);
  }
  final int treeCount = trees.size();
  for (int i = 0; i < treeCount; i++) {
    trees.get(i).validate();
  }
}
/**
 * Sums the weighted scores of all trees, in list order.
 *
 * @param modelFeatureValuesNormalized feature values indexed by model feature position
 * @return the sum of every tree's score for the given feature values
 */
@Override
public float score(float[] modelFeatureValuesNormalized) {
  float total = 0;
  for (int i = 0; i < trees.size(); i++) {
    total += trees.get(i).score(modelFeatureValuesNormalized);
  }
  return total;
}
/**
 * Walks down from the given node until a leaf is reached and returns that
 * leaf's value. At each split the walk goes left when the feature value is
 * less than or equal to the node's threshold and right otherwise.
 *
 * @param featureVector      feature values indexed by model feature position
 * @param regressionTreeNode node to start from
 * @return the value of the leaf reached, or 0 if a split refers to a feature
 *         index outside {@code featureVector}
 */
private static float scoreNode(float[] featureVector, RegressionTreeNode regressionTreeNode) {
  RegressionTreeNode current = regressionTreeNode;
  while (!current.isLeaf()) {
    final int idx = current.featureIndex;
    final boolean featureKnown = idx >= 0 && idx < featureVector.length;
    if (!featureKnown) {
      // unsupported feature (tree is looking for a feature that does not exist)
      return 0f;
    }
    current = (featureVector[idx] <= current.threshold) ? current.left : current.right;
  }
  return current.value;
}
/**
 * Iteratively checks every node reachable from the given node: a leaf must
 * have no children, and a split must have a threshold and both children.
 *
 * @param regressionTreeNode node to start from
 * @throws ModelException on the first malformed node encountered
 */
private static void validateNode(RegressionTreeNode regressionTreeNode) throws ModelException {
  // Explicit stack of nodes still to be checked, seeded with the start node.
  final Stack<RegressionTreeNode> pending = new Stack<>();
  pending.push(regressionTreeNode);
  while (!pending.isEmpty()) {
    final RegressionTreeNode node = pending.pop();
    if (node.isLeaf()) {
      final boolean hasChild = node.left != null || node.right != null;
      if (hasChild) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + node.left + " and right=" + node.right);
      }
    } else {
      if (node.threshold == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      }
      if (node.left == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      }
      if (node.right == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
      }
      // Left is pushed first, so the right subtree is checked first.
      pending.push(node.left);
      pending.push(node.right);
    }
  }
}
/**
 * Produces a textual trace of the path taken from the given node: one
 * "Go Left"/"Go Right" step per split, ending with the leaf value or with a
 * note that a split's feature is not present in the feature vector.
 *
 * @param featureVector      feature values indexed by model feature position
 * @param regressionTreeNode node to start from
 * @return the trace as a single string
 */
private static String explainNode(float[] featureVector, RegressionTreeNode regressionTreeNode) {
  final StringBuilder trace = new StringBuilder();
  RegressionTreeNode node = regressionTreeNode;
  while (!node.isLeaf()) {
    final int idx = node.featureIndex;
    if (idx < 0 || idx >= featureVector.length) {
      // unsupported feature (tree is looking for a feature that does not exist)
      trace.append('\'').append(node.feature).append("' does not exist in FV, Return Zero");
      return trace.toString();
    }
    // could store extra information about how much training data supported
    // each branch and report that here
    final float observed = featureVector[idx];
    final boolean goLeft = observed <= node.threshold;
    trace.append('\'').append(node.feature).append("':").append(observed)
        .append(goLeft ? " <= " : " > ")
        .append(node.threshold)
        .append(goLeft ? ", Go Left | " : ", Go Right | ");
    node = goLeft ? node.left : node.right;
  }
  trace.append("val: ").append(node.value);
  return trace.toString();
}
// /////////////////////////////////////////
// produces a string that looks like:
// -10.0 = MultipleAdditiveTreesModel(name=multipleadditivetreesmodel,trees=[...])
// model applied to features, sum of:
// 0.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right |
// 'this_feature_doesnt_exist' does not
// exist in FV, Return Zero
// -10.0 = tree 1 | val: -10.0
/**
 * Builds an explanation with one detail entry per tree, each holding that
 * tree's score and the path taken through it.
 *
 * @param context             not used by this implementation
 * @param doc                 not used by this implementation
 * @param finalScore          value reported on the top-level explanation
 * @param featureExplanations per-feature explanations whose values form the
 *                            feature vector, in iteration order
 * @return a match explanation whose details are the per-tree explanations
 */
@Override
public Explanation explain(LeafReaderContext context, int doc,
    float finalScore, List<Explanation> featureExplanations) {
  // Rebuild the feature vector from the values of the feature explanations.
  final float[] featureValues = new float[featureExplanations.size()];
  int position = 0;
  for (final Explanation featureExplanation : featureExplanations) {
    featureValues[position] = featureExplanation.getValue().floatValue();
    position++;
  }

  final List<Explanation> perTree = new ArrayList<>();
  int treeNumber = 0;
  for (final RegressionTree tree : trees) {
    final float treeScore = tree.score(featureValues);
    final String description = "tree " + treeNumber + " | " + tree.explain(featureValues);
    perTree.add(Explanation.match(treeScore, description));
    treeNumber++;
  }

  final String headline = toString() + " model applied to features, sum of:";
  return Explanation.match(finalScore, headline, perTree);
}
/**
 * Renders the model as {@code SimpleClassName(name=...,trees=[t0,t1,...])}.
 * An instance whose trees have not been set yet is rendered with an empty
 * tree list instead of throwing.
 *
 * @return a textual description of the model and its trees
 */
@Override
public String toString() {
  final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
  sb.append("(name=").append(getName());
  sb.append(",trees=[");
  // trees stays null until setTrees runs; toString must not throw in that
  // state, since it may be called while logging or debugging such a model.
  if (trees != null) {
    for (int ii = 0; ii < trees.size(); ++ii) {
      if (ii > 0) {
        sb.append(',');
      }
      sb.append(trees.get(ii));
    }
  }
  sb.append("])");
  return sb.toString();
}
}