blob: 4afb7b7636aa78cea64efff71356a0cb4e6920b8 [file] [log] [blame]
/* * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.SortedMap;
import java.util.stream.IntStream;
import org.apache.commons.io.FileUtils;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.embedded.JettyConfig;
import org.apache.solr.client.solrj.embedded.JettySolrRunner;
import org.apache.solr.client.solrj.request.CollectionAdminRequest;
import org.apache.solr.client.solrj.response.CollectionAdminResponse;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.cloud.MiniSolrCloudCluster;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.ltr.feature.FieldValueFeature;
import org.apache.solr.ltr.feature.OriginalScoreFeature;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LinearModel;
import org.apache.solr.util.RestTestHarness;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.AfterClass;
import org.junit.Test;
import static java.util.stream.Collectors.toList;
/**
 * Runs learning-to-rank feature extraction and re-ranking queries against a
 * {@link MiniSolrCloudCluster} whose number of shards, replicas and nodes is
 * chosen at random for every test.
 */
public class TestLTROnSolrCloud extends TestRerankBase {

  /** Cluster created in {@link #setUp()} and shut down in {@link #tearDown()}. */
  private MiniSolrCloudCluster solrCluster;

  /** Name of the solrconfig file handed to {@code setupTestInit}. */
  String solrconfig = "solrconfig-ltr.xml";

  /** Name of the schema file handed to {@code setupTestInit}. */
  String schema = "schema.xml";

  /** Extra servlets returned by {@code setupTestInit}; registered on every Jetty node. */
  SortedMap<ServletHolder,String> extraServlets = null;

  /**
   * Prepares the test Solr home, picks a random topology and starts the cluster.
   *
   * @throws Exception if the base set-up or the cluster set-up fails
   */
  @Override
  public void setUp() throws Exception {
    super.setUp();
    extraServlets = setupTestInit(solrconfig, schema, true);
    System.setProperty("enable.update.log", "true");

    // [1..4] shards, [1..2] replicas per shard, at most [1..4] cores per node
    int numberOfShards = random().nextInt(4)+1;
    int numberOfReplicas = random().nextInt(2)+1;
    int maxShardsPerNode = random().nextInt(4)+1;

    // integer ceiling of (total cores / maxShardsPerNode)
    int numberOfNodes = (numberOfShards*numberOfReplicas + (maxShardsPerNode-1))/maxShardsPerNode;

    setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes, maxShardsPerNode);
  }

  /**
   * Closes the REST harness, shuts the cluster down and runs the base tear-down.
   * Each step is attempted even when an earlier one throws, and fields that were
   * never assigned are skipped instead of causing a {@link NullPointerException}.
   *
   * @throws Exception if closing the harness, stopping the cluster or the base
   *         tear-down fails
   */
  @Override
  public void tearDown() throws Exception {
    try {
      if (restTestHarness != null) {
        restTestHarness.close();
      }
    } finally {
      restTestHarness = null;
      try {
        if (solrCluster != null) {
          solrCluster.shutdown();
        }
      } finally {
        super.tearDown();
      }
    }
  }

  /**
   * Checks result order, scores and logged feature vectors for a plain query, for
   * a query with the feature transformer only, and for a query re-ranked by the
   * {@code powpularityS-model} model.
   *
   * @throws Exception if any of the queries fails
   */
  @Test
  // commented 4-Sep-2018 @LuceneTestCase.BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 2-Aug-2018
  // commented out on: 24-Dec-2018 @BadApple(bugUrl="https://issues.apache.org/jira/browse/SOLR-12028") // 14-Oct-2018
  public void testSimpleQuery() throws Exception {
    // setUp() randomly picked a configuration with [1..4] shards and [1..2] replicas

    // Test regular query, it will sort the documents by inverse
    // popularity (the least popular, docid == 1, will be in the first
    // position)
    SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))");

    query.setRequestHandler("/query");
    query.setFields("*,score");
    query.setParam("rows", "8");

    QueryResponse queryResponse =
        solrCluster.getSolrClient().query(COLLECTION,query);
    assertEquals(8, queryResponse.getResults().getNumFound());
    assertResultIds(queryResponse, "1", "2", "3", "4", "5", "6", "7", "8");

    // scores of the plain query, indexed by result position
    final Float[] originalScores = new Float[8];
    for (int pos = 0; pos < originalScores.length; pos++) {
      originalScores[pos] = (Float)queryResponse.getResults().get(pos).get("score");
    }

    // Expected feature vectors, indexed by position in the re-ranked result list:
    // entry i belongs to the document with id (8 - i).
    final String[] expectedFeatures = new String[] {
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "64.0", "c3", "2.0", "original", "0.0", "dvIntFieldFeature", "8.0",
            "dvLongFieldFeature", "8.0", "dvFloatFieldFeature", "0.8", "dvDoubleFieldFeature", "0.8",
            "dvStrNumFieldFeature", "0.0", "dvStrBoolFieldFeature", "1.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "49.0", "c3", "2.0", "original", "1.0", "dvIntFieldFeature", "7.0",
            "dvLongFieldFeature", "7.0", "dvFloatFieldFeature", "0.7", "dvDoubleFieldFeature", "0.7",
            "dvStrNumFieldFeature", "1.0", "dvStrBoolFieldFeature", "0.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "36.0", "c3", "2.0", "original", "2.0", "dvIntFieldFeature", "6.0",
            "dvLongFieldFeature", "6.0", "dvFloatFieldFeature", "0.6", "dvDoubleFieldFeature", "0.6",
            "dvStrNumFieldFeature", "0.0", "dvStrBoolFieldFeature", "1.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "25.0", "c3", "2.0", "original", "3.0", "dvIntFieldFeature", "5.0",
            "dvLongFieldFeature", "5.0", "dvFloatFieldFeature", "0.5", "dvDoubleFieldFeature", "0.5",
            "dvStrNumFieldFeature", "1.0", "dvStrBoolFieldFeature", "0.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "16.0", "c3", "2.0", "original", "4.0", "dvIntFieldFeature", "4.0",
            "dvLongFieldFeature", "4.0", "dvFloatFieldFeature", "0.4", "dvDoubleFieldFeature", "0.4",
            "dvStrNumFieldFeature", "0.0", "dvStrBoolFieldFeature", "1.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "9.0", "c3", "2.0", "original", "5.0", "dvIntFieldFeature", "3.0",
            "dvLongFieldFeature", "3.0", "dvFloatFieldFeature", "0.3", "dvDoubleFieldFeature", "0.3",
            "dvStrNumFieldFeature", "1.0", "dvStrBoolFieldFeature", "0.0"),
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "4.0", "c3", "2.0", "original", "6.0", "dvIntFieldFeature", "2.0",
            "dvLongFieldFeature", "2.0", "dvFloatFieldFeature", "0.2", "dvDoubleFieldFeature", "0.2",
            "dvStrNumFieldFeature", "0.0", "dvStrBoolFieldFeature", "1.0"),
        // the document with popularity 1 is indexed without the dv* fields
        FeatureLoggerTestUtils.toFeatureVector(
            "powpularityS", "1.0", "c3", "2.0", "original", "7.0", "dvIntFieldFeature", "-1.0",
            "dvLongFieldFeature", "-2.0", "dvFloatFieldFeature", "-3.0", "dvDoubleFieldFeature", "-4.0",
            "dvStrNumFieldFeature", "-5.0", "dvStrBoolFieldFeature", "0.0")
    };
    final int lastPos = expectedFeatures.length - 1;

    // Test feature vectors returned (without re-ranking)
    query.setFields("*,score,features:[fv store=test]");
    queryResponse =
        solrCluster.getSolrClient().query(COLLECTION,query);
    assertEquals(8, queryResponse.getResults().getNumFound());
    assertResultIds(queryResponse, "1", "2", "3", "4", "5", "6", "7", "8");

    // adding the transformer must not change the scores
    for (int pos = 0; pos < originalScores.length; pos++) {
      assertEquals(originalScores[pos], queryResponse.getResults().get(pos).get("score"));
    }

    // without re-ranking the order is reversed, so position i carries the vector
    // expected at re-ranked position (lastPos - i)
    for (int pos = 0; pos <= lastPos; pos++) {
      assertEquals(expectedFeatures[lastPos - pos],
          queryResponse.getResults().get(pos).get("features").toString());
    }

    // Test feature vectors returned (with re-ranking)
    query.setFields("*,score,features:[fv]");
    query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}");
    queryResponse =
        solrCluster.getSolrClient().query(COLLECTION,query);
    assertEquals(8, queryResponse.getResults().getNumFound());

    final String[] rerankedIds = new String[] {"8", "7", "6", "5", "4", "3", "2", "1"};
    for (int pos = 0; pos < rerankedIds.length; pos++) {
      assertEquals(rerankedIds[pos], queryResponse.getResults().get(pos).get("id").toString());
      assertEquals(expectedFeatures[pos],
          queryResponse.getResults().get(pos).get("features").toString());
    }
  }

  /**
   * Asserts that the leading results of {@code response} carry the given ids, in order.
   *
   * @param response response whose result list is inspected
   * @param expectedIds expected {@code id} field values, by result position
   */
  private static void assertResultIds(QueryResponse response, String... expectedIds) {
    for (int pos = 0; pos < expectedIds.length; pos++) {
      assertEquals(expectedIds[pos], response.getResults().get(pos).get("id").toString());
    }
  }

  /**
   * Starts the cluster, uploads the config set, creates and fills the collection,
   * points the REST harness at one core and loads the features and the model.
   *
   * @param numShards number of shards of the collection
   * @param numReplicas number of replicas per shard
   * @param numServers number of Jetty nodes to start
   * @param maxShardsPerNode maximum number of cores a single node may host
   * @throws Exception if any step of the cluster set-up fails
   */
  private void setupSolrCluster(int numShards, int numReplicas, int numServers, int maxShardsPerNode) throws Exception {
    JettyConfig jc = buildJettyConfig("/solr");
    jc = JettyConfig.builder(jc).withServlets(extraServlets).build();
    solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome.toPath(), jc);
    File configDir = tmpSolrHome.toPath().resolve("collection1/conf").toFile();
    solrCluster.uploadConfigSet(configDir.toPath(), "conf1");

    solrCluster.getSolrClient().setDefaultCollection(COLLECTION);

    createCollection(COLLECTION, "conf1", numShards, numReplicas, maxShardsPerNode);
    indexDocuments(COLLECTION);

    // use the first core found on any node as the target of the REST harness
    for (JettySolrRunner solrRunner : solrCluster.getJettySolrRunners()) {
      if (!solrRunner.getCoreContainer().getCores().isEmpty()){
        String coreName = solrRunner.getCoreContainer().getCores().iterator().next().getName();
        restTestHarness = new RestTestHarness(() -> solrRunner.getBaseUrl().toString() + "/" + coreName);
        break;
      }
    }
    loadModelsAndFeatures();
  }

  /**
   * Creates a collection and waits until all of its replicas are active.
   *
   * @param name collection name
   * @param config name of the uploaded config set
   * @param numShards number of shards
   * @param numReplicas number of replicas per shard
   * @param maxShardsPerNode maximum number of cores per node
   * @throws Exception if the create request or the wait fails
   */
  private void createCollection(String name, String config, int numShards, int numReplicas, int maxShardsPerNode)
      throws Exception {
    CollectionAdminResponse response;
    CollectionAdminRequest.Create create =
        CollectionAdminRequest.createCollection(name, config, numShards, numReplicas);
    create.setMaxShardsPerNode(maxShardsPerNode);
    response = create.process(solrCluster.getSolrClient());

    if (response.getStatus() != 0 || response.getErrorMessages() != null) {
      fail("Could not create collection. Response: " + response.toString());
    }
    // NOTE(review): this reader is never used afterwards; the call is retained
    // because its side effects, if any, are not visible from this file - confirm
    // before removing it together with the ZkStateReader import.
    ZkStateReader zkStateReader = solrCluster.getSolrClient().getZkStateReader();
    solrCluster.waitForActiveCollection(name, numShards, numShards * numReplicas);
  }

  /**
   * Adds one document to the collection. The {@code dv*} fields are filled in
   * for every document except the one with popularity 1.
   *
   * @param collection target collection
   * @param id value of the {@code id} field
   * @param title value of the {@code title} field
   * @param description value of the {@code description} field
   * @param popularity value of the {@code popularity} field, also used to derive the {@code dv*} values
   * @throws Exception if the add request fails
   */
  void indexDocument(String collection, String id, String title, String description, int popularity)
      throws Exception{
    SolrInputDocument doc = new SolrInputDocument();
    doc.setField("id", id);
    doc.setField("title", title);
    doc.setField("description", description);
    doc.setField("popularity", popularity);
    if (popularity != 1) {
      // check that empty values will be read as default
      doc.setField("dvIntField", popularity);
      doc.setField("dvLongField", popularity);
      doc.setField("dvFloatField", ((float) popularity) / 10);
      doc.setField("dvDoubleField", ((double) popularity) / 10);
      doc.setField("dvStrNumField", popularity % 2 == 0 ? "F" : "T");
      doc.setField("dvStrBoolField", popularity % 2 == 0 ? "T" : "F");
    }
    solrCluster.getSolrClient().add(collection, doc);
  }

  /**
   * Indexes eight documents (ids 1..8, popularity equal to the id) in shuffled
   * order, optionally committing half-way through, and commits at the end.
   *
   * @param collection target collection
   * @throws Exception if an add or commit request fails
   */
  private void indexDocuments(final String collection)
      throws Exception {
    final int collectionSize = 8;
    // put documents in random order to check that advanceExact is working correctly
    List<Integer> docIds = IntStream.rangeClosed(1, collectionSize).boxed().collect(toList());
    Collections.shuffle(docIds, random());

    int docCounter = 1;
    for (int docId : docIds) {
      final int popularity = docId;
      indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity);
      // maybe commit in the middle in order to check that everything works fine for multi-segment case
      if (docCounter == collectionSize / 2 && random().nextBoolean()) {
        solrCluster.getSolrClient().commit(collection);
      }
      docCounter++;
    }
    solrCluster.getSolrClient().commit(collection, true, true);
  }

  /**
   * Registers the nine features of the {@code test} feature store and the linear
   * model {@code powpularityS-model} that uses them, then reloads the collection.
   *
   * @throws Exception if loading a feature, loading the model or the reload fails
   */
  private void loadModelsAndFeatures() throws Exception {
    final String featureStore = "test";
    final String[] featureNames = new String[]{"powpularityS", "c3", "original", "dvIntFieldFeature",
        "dvLongFieldFeature", "dvFloatFieldFeature", "dvDoubleFieldFeature", "dvStrNumFieldFeature", "dvStrBoolFieldFeature"};
    final String jsonModelParams = "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0,\"original\":0.1," +
        "\"dvIntFieldFeature\":0.1,\"dvLongFieldFeature\":0.1," +
        "\"dvFloatFieldFeature\":0.1,\"dvDoubleFieldFeature\":0.1,\"dvStrNumFieldFeature\":0.1,\"dvStrBoolFieldFeature\":0.1}}";

    // popularity squared, computed by a function query
    loadFeature(
        featureNames[0],
        SolrFeature.class.getName(),
        featureStore,
        "{\"q\":\"{!func}pow(popularity,2)\"}"
    );
    // constant feature
    loadFeature(
        featureNames[1],
        ValueFeature.class.getName(),
        featureStore,
        "{\"value\":2}"
    );
    // score of the original query
    loadFeature(
        featureNames[2],
        OriginalScoreFeature.class.getName(),
        featureStore,
        null
    );
    // field value features; only the double and string-number ones set defaultValue here
    loadFeature(
        featureNames[3],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvIntField\"}"
    );
    loadFeature(
        featureNames[4],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvLongField\"}"
    );
    loadFeature(
        featureNames[5],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvFloatField\"}"
    );
    loadFeature(
        featureNames[6],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvDoubleField\",\"defaultValue\":-4.0}"
    );
    loadFeature(
        featureNames[7],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvStrNumField\",\"defaultValue\":-5}"
    );
    loadFeature(
        featureNames[8],
        FieldValueFeature.class.getName(),
        featureStore,
        "{\"field\":\"dvStrBoolField\"}"
    );

    loadModel(
        "powpularityS-model",
        LinearModel.class.getName(),
        featureNames,
        featureStore,
        jsonModelParams
    );
    reloadCollection(COLLECTION);
  }

  /**
   * Reloads the collection and asserts that the request succeeded.
   *
   * @param collection name of the collection to reload
   * @throws Exception if the reload request fails
   */
  private void reloadCollection(String collection) throws Exception {
    CollectionAdminRequest.Reload reloadRequest = CollectionAdminRequest.reloadCollection(collection);
    CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient());
    assertEquals(0, response.getStatus());
    assertTrue(response.isSuccess());
  }

  /**
   * Deletes the temporary Solr home, if there is one, and clears the
   * {@code managed.schema.mutable} system property.
   *
   * @throws Exception if the directory cannot be deleted
   */
  @AfterClass
  public static void after() throws Exception {
    if (null != tmpSolrHome) {
      FileUtils.deleteDirectory(tmpSolrHome);
      tmpSolrHome = null;
    }
    System.clearProperty("managed.schema.mutable");
  }
}