// blob: 4e397a55f32e68b0ae045eff43ad29c8b0f31f12 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.File;
import java.lang.invoke.MethodHandles;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.commons.io.FileUtils;
import org.apache.solr.common.util.Utils;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.apache.solr.util.RestTestBase;
import org.eclipse.jetty.servlet.ServletHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TestRerankBase extends RestTestBase {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
protected static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(Paths.get("").toAbsolutePath());
protected static File tmpSolrHome;
protected static File tmpConfDir;
public static final String FEATURE_FILE_NAME = "_schema_feature-store.json";
public static final String MODEL_FILE_NAME = "_schema_model-store.json";
public static final String PARENT_ENDPOINT = "/schema/*";
protected static final String COLLECTION = "collection1";
protected static final String CONF_DIR = COLLECTION + "/conf";
protected static File fstorefile = null;
protected static File mstorefile = null;
final private static String SYSTEM_PROPERTY_SOLR_LTR_TRANSFORMER_FV_DEFAULTFORMAT = "solr.ltr.transformer.fv.defaultFormat";
private static String defaultFeatureFormat;
protected String chooseDefaultFeatureVector(String dense, String sparse) {
if (defaultFeatureFormat == null) {
// to match <code><str name="defaultFormat">${solr.ltr.transformer.fv.defaultFormat:dense}</str></code> snippet
return dense;
} else if ("dense".equals(defaultFeatureFormat)) {
return dense;
} else if ("sparse".equals(defaultFeatureFormat)) {
return sparse;
} else {
fail("unexpected feature format choice: "+defaultFeatureFormat);
return null;
}
}
protected static void chooseDefaultFeatureFormat() throws Exception {
switch (random().nextInt(3)) {
case 0:
defaultFeatureFormat = null;
break;
case 1:
defaultFeatureFormat = "dense";
break;
case 2:
defaultFeatureFormat = "sparse";
break;
default:
fail("unexpected feature format choice");
break;
}
if (defaultFeatureFormat != null) {
System.setProperty(SYSTEM_PROPERTY_SOLR_LTR_TRANSFORMER_FV_DEFAULTFORMAT, defaultFeatureFormat);
}
}
protected static void unchooseDefaultFeatureFormat() {
System.clearProperty(SYSTEM_PROPERTY_SOLR_LTR_TRANSFORMER_FV_DEFAULTFORMAT);
}
protected static void setuptest(boolean bulkIndex) throws Exception {
chooseDefaultFeatureFormat();
setuptest("solrconfig-ltr.xml", "schema.xml");
if (bulkIndex) bulkIndex();
}
protected static void setupPersistenttest(boolean bulkIndex) throws Exception {
chooseDefaultFeatureFormat();
setupPersistentTest("solrconfig-ltr.xml", "schema.xml");
if (bulkIndex) bulkIndex();
}
public static ManagedFeatureStore getManagedFeatureStore() {
try (SolrCore core = jetty.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) {
return ManagedFeatureStore.getManagedFeatureStore(core);
}
}
public static ManagedModelStore getManagedModelStore() {
try (SolrCore core = jetty.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) {
return ManagedModelStore.getManagedModelStore(core);
}
}
protected static SortedMap<ServletHolder,String> setupTestInit(
String solrconfig, String schema,
boolean isPersistent) throws Exception {
tmpSolrHome = createTempDir().toFile();
tmpConfDir = new File(tmpSolrHome, CONF_DIR);
tmpConfDir.deleteOnExit();
FileUtils.copyDirectory(new File(TEST_HOME()),
tmpSolrHome.getAbsoluteFile());
final File fstore = new File(tmpConfDir, FEATURE_FILE_NAME);
final File mstore = new File(tmpConfDir, MODEL_FILE_NAME);
if (isPersistent) {
fstorefile = fstore;
mstorefile = mstore;
}
if (fstore.exists()) {
if (log.isInfoEnabled()) {
log.info("remove feature store config file in {}", fstore.getAbsolutePath());
}
Files.delete(fstore.toPath());
}
if (mstore.exists()) {
if (log.isInfoEnabled()) {
log.info("remove model store config file in {}", mstore.getAbsolutePath());
}
Files.delete(mstore.toPath());
}
if (!solrconfig.equals("solrconfig.xml")) {
FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/" + solrconfig),
new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/solrconfig.xml"));
}
if (!schema.equals("schema.xml")) {
FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/" + schema),
new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/schema.xml"));
}
System.setProperty("managed.schema.mutable", "true");
final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>();
return extraServlets;
}
public static void setuptest(String solrconfig, String schema)
throws Exception {
SortedMap<ServletHolder,String> extraServlets =
setupTestInit(solrconfig,schema,false);
System.setProperty("enable.update.log", "false");
createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
"/solr", true, extraServlets);
}
public static void setupPersistentTest(String solrconfig, String schema)
throws Exception {
SortedMap<ServletHolder,String> extraServlets =
setupTestInit(solrconfig,schema,true);
createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
"/solr", true, extraServlets);
}
protected static void aftertest() throws Exception {
if (null != restTestHarness) {
restTestHarness.close();
restTestHarness = null;
}
if (null != jetty) {
jetty.stop();
jetty = null;
}
if (null != tmpSolrHome) {
FileUtils.deleteDirectory(tmpSolrHome);
tmpSolrHome = null;
}
System.clearProperty("managed.schema.mutable");
// System.clearProperty("enable.update.log");
unchooseDefaultFeatureFormat();
}
public static void makeRestTestHarnessNull() {
restTestHarness = null;
}
/** produces a model encoded in json **/
public static String getModelInJson(String name, String type,
String[] features, String fstore, String params) {
final StringBuilder sb = new StringBuilder();
sb.append("{\n");
sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
sb.append("\"store\":").append('"').append(fstore).append('"')
.append(",\n");
sb.append("\"class\":").append('"').append(type).append('"').append(",\n");
sb.append("\"features\":").append('[');
if (features.length > 0) {
for (final String feature : features) {
sb.append("\n\t{ ");
sb.append("\"name\":").append('"').append(feature).append('"')
.append("},");
}
sb.deleteCharAt(sb.length() - 1);
}
sb.append("\n]\n");
if (params != null) {
sb.append(",\n");
sb.append("\"params\":").append(params);
}
sb.append("\n}\n");
return sb.toString();
}
/** produces a model encoded in json **/
public static String getFeatureInJson(String name, String type,
String fstore, String params) {
final StringBuilder sb = new StringBuilder();
sb.append("{\n");
sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
sb.append("\"store\":").append('"').append(fstore).append('"')
.append(",\n");
sb.append("\"class\":").append('"').append(type).append('"');
if (params != null) {
sb.append(",\n");
sb.append("\"params\":").append(params);
}
sb.append("\n}\n");
return sb.toString();
}
protected static void loadFeature(String name, String type, String params)
throws Exception {
final String feature = getFeatureInJson(name, type, "test", params);
log.info("loading feauture \n{} ", feature);
assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
"/responseHeader/status==0");
}
protected static void loadFeature(String name, String type, String fstore,
String params) throws Exception {
final String feature = getFeatureInJson(name, type, fstore, params);
log.info("loading feauture \n{} ", feature);
assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
"/responseHeader/status==0");
}
protected static void loadModel(String name, String type, String[] features,
String params) throws Exception {
loadModel(name, type, features, "test", params);
}
protected static void loadModel(String name, String type, String[] features,
String fstore, String params) throws Exception {
final String model = getModelInJson(name, type, features, fstore, params);
log.info("loading model \n{} ", model);
assertJPut(ManagedModelStore.REST_END_POINT, model,
"/responseHeader/status==0");
}
public static void loadModels(String fileName) throws Exception {
final URL url = TestRerankBase.class.getResource("/modelExamples/"
+ fileName);
final String multipleModels = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
assertJPut(ManagedModelStore.REST_END_POINT, multipleModels,
"/responseHeader/status==0");
}
public static LTRScoringModel createModelFromFiles(String modelFileName,
String featureFileName) throws ModelException, Exception {
return createModelFromFiles(modelFileName, featureFileName,
FeatureStore.DEFAULT_FEATURE_STORE_NAME);
}
public static LTRScoringModel createModelFromFiles(String modelFileName,
String featureFileName, String featureStoreName) throws ModelException, Exception {
URL url = TestRerankBase.class.getResource("/modelExamples/"
+ modelFileName);
final String modelJson = FileUtils.readFileToString(new File(url.toURI()),
"UTF-8");
final ManagedModelStore ms = getManagedModelStore();
url = TestRerankBase.class.getResource("/featureExamples/"
+ featureFileName);
final String featureJson = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
Object parsedFeatureJson = null;
try {
parsedFeatureJson = Utils.fromJSONString(featureJson);
} catch (final Exception ioExc) {
throw new ModelException("ObjectBuilder failed parsing json", ioExc);
}
final ManagedFeatureStore fs = getManagedFeatureStore();
// fs.getFeatureStore(null).clear();
fs.doDeleteChild(null, featureStoreName); // is this safe??
// based on my need to call this I dont think that
// "getNewManagedFeatureStore()"
// is actually returning a new feature store each time
fs.applyUpdatesToManagedData(parsedFeatureJson);
ms.setManagedFeatureStore(fs); // can we skip this and just use fs directly below?
final LTRScoringModel ltrScoringModel = ManagedModelStore.fromLTRScoringModelMap(
solrResourceLoader, mapFromJson(modelJson), ms.getManagedFeatureStore());
ms.addModel(ltrScoringModel);
return ltrScoringModel;
}
@SuppressWarnings("unchecked")
static private Map<String,Object> mapFromJson(String json) throws ModelException {
Object parsedJson = null;
try {
parsedJson = Utils.fromJSONString(json);
} catch (final Exception ioExc) {
throw new ModelException("ObjectBuilder failed parsing json", ioExc);
}
return (Map<String,Object>) parsedJson;
}
public static void loadFeatures(String fileName) throws Exception {
final URL url = TestRerankBase.class.getResource("/featureExamples/"
+ fileName);
final String multipleFeatures = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
log.info("send \n{}", multipleFeatures);
assertJPut(ManagedFeatureStore.REST_END_POINT, multipleFeatures,
"/responseHeader/status==0");
}
protected List<Feature> getFeatures(List<String> names)
throws FeatureException {
final List<Feature> features = new ArrayList<>();
int pos = 0;
for (final String name : names) {
final Map<String,Object> params = new HashMap<String,Object>();
params.put("value", 10);
final Feature f = Feature.getInstance(solrResourceLoader,
ValueFeature.class.getCanonicalName(),
name, params);
f.setIndex(pos);
features.add(f);
++pos;
}
return features;
}
protected List<Feature> getFeatures(String[] names) throws FeatureException {
return getFeatures(Arrays.asList(names));
}
protected static void bulkIndex() throws Exception {
assertU(adoc("title", "bloomberg different bla", "description",
"bloomberg", "id", "6", "popularity", "1"));
assertU(adoc("title", "bloomberg bloomberg ", "description", "bloomberg",
"id", "7", "popularity", "2"));
assertU(adoc("title", "bloomberg bloomberg bloomberg", "description",
"bloomberg", "id", "8", "popularity", "3"));
assertU(adoc("title", "bloomberg bloomberg bloomberg bloomberg",
"description", "bloomberg", "id", "9", "popularity", "5"));
assertU(commit());
}
protected static void doTestParamsToMap(String featureClassName,
LinkedHashMap<String,Object> featureParams) throws Exception {
// start with default parameters
final LinkedHashMap<String,Object> paramsA = new LinkedHashMap<String,Object>();
final Object defaultValue;
switch (random().nextInt(6)) {
case 0:
defaultValue = null;
break;
case 1:
defaultValue = "1.2";
break;
case 2:
defaultValue = Double.valueOf(3.4d);
break;
case 3:
defaultValue = Float.valueOf(0.5f);
break;
case 4:
defaultValue = Integer.valueOf(67);
break;
case 5:
defaultValue = Long.valueOf(89);
break;
default:
defaultValue = null;
fail("unexpected defaultValue choice");
break;
}
if (defaultValue != null) {
paramsA.put("defaultValue", defaultValue);
}
// then add specific parameters
paramsA.putAll(featureParams);
// next choose a random feature name
final String featureName = "randomFeatureName"+random().nextInt(10);
// create a feature from the parameters
final Feature featureA = Feature.getInstance(solrResourceLoader,
featureClassName, featureName, paramsA);
// turn the feature back into parameters
final LinkedHashMap<String,Object> paramsB = featureA.paramsToMap();
// create feature B from feature A's parameters
final Feature featureB = Feature.getInstance(solrResourceLoader,
featureClassName, featureName, paramsB);
// check that feature A and feature B are identical
assertEquals(featureA, featureB);
}
}