blob: 0627d79b8e2bfca7058bc8c88d39dbef664513a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.store.rest;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.AdapterModel;
import org.apache.solr.ltr.model.WrapperModel;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.ModelStore;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.BaseSolrResource;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.rest.ManagedResourceStorage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Managed resource for storing a model
*/
public class ManagedModelStore extends ManagedResource implements ManagedResource.ChildResourceSupport {

  /**
   * Registers this class with the managed resource registry of the given
   * loader, under {@link #REST_END_POINT}.
   *
   * @param solrResourceLoader loader whose managed resource registry receives the registration
   * @param managedResourceObserver observer handed to the registry together with the registration
   */
  public static void registerManagedModelStore(SolrResourceLoader solrResourceLoader,
      ManagedResourceObserver managedResourceObserver) {
    solrResourceLoader.getManagedResourceRegistry().registerManagedResource(
        REST_END_POINT,
        ManagedModelStore.class,
        managedResourceObserver);
  }

  /**
   * Looks up the managed resource registered at {@link #REST_END_POINT} in the
   * rest manager of the given core and returns it as a model store.
   *
   * @param core the core whose rest manager is queried
   * @return the managed model store of the core
   */
  public static ManagedModelStore getManagedModelStore(SolrCore core) {
    return (ManagedModelStore) core.getRestManager()
        .getManagedResource(REST_END_POINT);
  }

  /** the model store rest endpoint **/
  public static final String REST_END_POINT = "/schema/model-store";

  /**
   * Managed model store: the name of the attribute containing all the models of
   * a model store
   **/
  private static final String MODELS_JSON_FIELD = "models";

  /** name of the attribute containing a class **/
  static final String CLASS_KEY = "class";
  /** name of the attribute containing the features **/
  static final String FEATURES_KEY = "features";
  /** name of the attribute containing a name **/
  static final String NAME_KEY = "name";
  /** name of the attribute containing a normalizer **/
  static final String NORM_KEY = "norm";
  /** name of the attribute containing parameters **/
  static final String PARAMS_KEY = "params";
  /** name of the attribute containing a store **/
  static final String STORE_KEY = "store";

  /** the models currently held by this resource **/
  private final ModelStore store;

  /** feature stores used to resolve the features referenced by a model; set after construction **/
  private ManagedFeatureStore managedFeatureStore;

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /**
   * Creates the resource with an empty {@link ModelStore}.
   *
   * @param resourceId passed through to the superclass constructor
   * @param loader passed through to the superclass constructor
   * @param storageIO passed through to the superclass constructor
   * @throws SolrException declared for the superclass constructor
   */
  public ManagedModelStore(String resourceId, SolrResourceLoader loader,
      ManagedResourceStorage.StorageIO storageIO) throws SolrException {
    super(resourceId, loader, storageIO);
    store = new ModelStore();
  }

  /**
   * Sets the feature stores that models are resolved against. This has to be
   * called before {@link #loadStoredModels()} or any update, since building a
   * model dereferences the managed feature store.
   *
   * @param managedFeatureStore the managed feature store to use
   */
  public void setManagedFeatureStore(ManagedFeatureStore managedFeatureStore) {
    log.info("INIT model store");
    this.managedFeatureStore = managedFeatureStore;
  }

  /**
   * @return the managed feature store set through
   *         {@link #setManagedFeatureStore(ManagedFeatureStore)}, or null if none was set
   */
  public ManagedFeatureStore getManagedFeatureStore() {
    return managedFeatureStore;
  }

  /** raw data handed over by the storage layer, kept until {@link #loadStoredModels()} consumes it **/
  private Object managedData;

  /**
   * Clears the model store and remembers the raw stored data without
   * interpreting it.
   *
   * @param managedInitArgs unused
   * @param managedData the raw data read from the storage
   */
  @Override
  protected void onManagedDataLoadedFromStorage(NamedList<?> managedInitArgs,
      Object managedData) throws SolrException {
    store.clear();
    // The managed models on the disk or on zookeeper are loaded lazily (see
    // loadStoredModels), since the managed features need to be set first:
    // unfortunately managed resources do not decouple the creation of a
    // managed resource from the reading of its data from the storage.
    this.managedData = managedData;
  }

  /**
   * Builds and adds a model for every entry of the data previously received
   * from the storage. Does nothing if that data is absent or is not a list.
   *
   * @throws SolrException (BAD_REQUEST) if a stored model cannot be built or added
   */
  public void loadStoredModels() {
    log.info("------ managed models ~ loading ------");
    // instanceof is false for null, so no separate null check is needed
    if (managedData instanceof List) {
      @SuppressWarnings({"unchecked"})
      final List<Map<String,Object>> up = (List<Map<String,Object>>) managedData;
      for (final Map<String,Object> u : up) {
        addModelFromMap(u);
      }
    }
  }

  /**
   * Builds a model from its map representation and adds it to the store,
   * converting a {@link ModelException} into a BAD_REQUEST {@link SolrException}.
   */
  private void addModelFromMap(Map<String,Object> modelMap) {
    try {
      final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, modelMap, managedFeatureStore);
      addModel(algo);
    } catch (final ModelException e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }

  /**
   * Adds the given model to the store.
   *
   * @param ltrScoringModel the model to add
   * @throws SolrException (BAD_REQUEST) wrapping any {@link ModelException}
   *         raised by the store
   */
  public synchronized void addModel(LTRScoringModel ltrScoringModel) throws ModelException {
    try {
      if (log.isInfoEnabled()) {
        log.info("adding model {}", ltrScoringModel.getName());
      }
      store.addModel(ltrScoringModel);
    } catch (final ModelException e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }

  /**
   * Adds the model(s) described by the update and returns all models of the
   * store in their map representation.
   *
   * @param updates either a list of model maps, a single model map, or anything
   *        else (including null), in which case no model is added
   * @return the models of the store as a list of maps
   */
  @SuppressWarnings("unchecked")
  @Override
  protected Object applyUpdatesToManagedData(Object updates) {
    if (updates instanceof List) {
      final List<Map<String,Object>> up = (List<Map<String,Object>>) updates;
      for (final Map<String,Object> u : up) {
        addModelFromMap(u);
      }
    } else if (updates instanceof Map) {
      final Map<String,Object> map = (Map<String,Object>) updates;
      addModelFromMap(map);
    }
    return modelsAsManagedResources(store.getModels());
  }

  /**
   * Asks the store to delete the entry identified by childId and then stores
   * the remaining models.
   *
   * @param endpoint unused
   * @param childId identifier handed to the store for deletion
   */
  @Override
  public synchronized void doDeleteChild(BaseSolrResource endpoint, String childId) {
    store.delete(childId);
    // a null update adds nothing and just yields the current models as maps
    storeManagedData(applyUpdatesToManagedData(null));
  }

  /**
   * Called to retrieve a named part (the given childId) of the resource at the
   * given endpoint. Note: since we have a unique child managed store we ignore
   * the childId.
   */
  @Override
  public void doGet(BaseSolrResource endpoint, String childId) {
    final SolrQueryResponse response = endpoint.getSolrResponse();
    response.add(MODELS_JSON_FIELD,
        modelsAsManagedResources(store.getModels()));
  }

  /**
   * @param modelName the name of the model to look up
   * @return whatever the underlying store returns for the given name
   */
  public LTRScoringModel getModel(String modelName) {
    // this function replicates getModelStore().getModel(modelName), but
    // it simplifies the testing (we can avoid to mock also a ModelStore).
    return store.getModel(modelName);
  }

  @Override
  public String toString() {
    return "ManagedModelStore [store=" + store + ", featureStores="
        + managedFeatureStore + "]";
  }

  /**
   * Returns the available models as a list of Maps objects. After an update the
   * managed resources needs to return the resources in this format in order to
   * store in json somewhere (zookeeper, disk...)
   *
   *
   * @return the available models as a list of Maps objects
   */
  private static List<Object> modelsAsManagedResources(List<LTRScoringModel> models) {
    final List<Object> list = new ArrayList<>(models.size());
    for (final LTRScoringModel model : models) {
      list.add(toLTRScoringModelMap(model));
    }
    return list;
  }

  /**
   * Builds a model from its map representation.
   *
   * <p>The feature store is looked up by the map's {@link #STORE_KEY} entry.
   * Every entry of the optional {@link #FEATURES_KEY} list is resolved to a
   * feature of that store and paired with a normalizer. If the resulting model
   * is an {@link AdapterModel} it is initialised before being returned, so
   * callers must not initialise it again.
   *
   * @param solrResourceLoader loader used to instantiate the model and its normalizers
   * @param modelMap the map representation of the model
   * @param managedFeatureStore the feature stores to resolve the model's features against
   * @return the fully initialised model
   * @throws ModelException if a referenced feature is not found in the feature store
   */
  @SuppressWarnings("unchecked")
  public static LTRScoringModel fromLTRScoringModelMap(SolrResourceLoader solrResourceLoader,
      Map<String,Object> modelMap, ManagedFeatureStore managedFeatureStore) {

    final FeatureStore featureStore =
        managedFeatureStore.getFeatureStore((String) modelMap.get(STORE_KEY));

    final List<Feature> features = new ArrayList<>();
    final List<Normalizer> norms = new ArrayList<>();

    final List<Object> featureList = (List<Object>) modelMap.get(FEATURES_KEY);
    if (featureList != null) {
      for (final Object feature : featureList) {
        final Map<String,Object> featureMap = (Map<String,Object>) feature;
        features.add(lookupFeatureFromFeatureMap(featureMap, featureStore));
        norms.add(createNormalizerFromFeatureMap(solrResourceLoader, featureMap));
      }
    }

    final LTRScoringModel ltrScoringModel = LTRScoringModel.getInstance(solrResourceLoader,
        (String) modelMap.get(CLASS_KEY), // modelClassName
        (String) modelMap.get(NAME_KEY), // modelName
        features,
        norms,
        featureStore.getName(),
        featureStore.getFeatures(),
        (Map<String,Object>) modelMap.get(PARAMS_KEY));

    if (ltrScoringModel instanceof AdapterModel) {
      initAdapterModel(solrResourceLoader, (AdapterModel)ltrScoringModel, managedFeatureStore);
    }

    return ltrScoringModel;
  }

  /**
   * Initialises an adapter model with the resource loader and, if it is also a
   * {@link WrapperModel}, resolves the model it wraps.
   */
  private static void initAdapterModel(SolrResourceLoader solrResourceLoader,
      AdapterModel adapterModel, ManagedFeatureStore managedFeatureStore) {
    adapterModel.init(solrResourceLoader);
    if (adapterModel instanceof WrapperModel) {
      initWrapperModel(solrResourceLoader, (WrapperModel)adapterModel, managedFeatureStore);
    }
  }

  /**
   * Builds the model described by the wrapper's model map and hands it to the
   * wrapper.
   */
  private static void initWrapperModel(SolrResourceLoader solrResourceLoader,
      WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) {
    // fromLTRScoringModelMap already initialises the model it returns when that
    // model is an AdapterModel. Initialising it a second time here would repeat
    // the init call and, for a wrapped WrapperModel, fetch its model map and
    // rebuild its inner model twice.
    final LTRScoringModel model = fromLTRScoringModelMap(
        solrResourceLoader,
        wrapperModel.fetchModelMap(),
        managedFeatureStore);
    wrapperModel.updateModel(model);
  }

  /**
   * Converts a model into its map representation (name, class, store, features,
   * params). The feature list is left empty for a {@link WrapperModel}.
   */
  private static LinkedHashMap<String,Object> toLTRScoringModelMap(LTRScoringModel model) {
    final LinkedHashMap<String,Object> modelMap = new LinkedHashMap<>(5, 1.0f);

    modelMap.put(NAME_KEY, model.getName());
    modelMap.put(CLASS_KEY, model.getClass().getName());
    modelMap.put(STORE_KEY, model.getFeatureStoreName());

    final List<Map<String,Object>> features = new ArrayList<>();
    if (!(model instanceof WrapperModel)) {
      final List<Feature> featuresList = model.getFeatures();
      final List<Normalizer> normsList = model.getNorms();
      for (int ii = 0; ii < featuresList.size(); ++ii) {
        features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
      }
    }
    modelMap.put(FEATURES_KEY, features);

    modelMap.put(PARAMS_KEY, model.getParams());

    return modelMap;
  }

  /**
   * Resolves the feature named by the {@link #NAME_KEY} entry of the feature
   * map in the given store.
   *
   * @throws ModelException if the name is missing or the store has no such
   *         feature; the message distinguishes an empty store from a missing feature
   */
  private static Feature lookupFeatureFromFeatureMap(Map<String, Object> featureMap, FeatureStore featureStore)
  {
    final String featureName = (String) featureMap.get(NAME_KEY);
    Feature extractedFromStore = featureName == null ? null : featureStore.get(featureName);
    if (extractedFromStore == null) {
      if (featureStore.getFeatures().isEmpty()) {
        throw new ModelException("Missing or empty feature store: " + featureStore.getName());
      } else {
        throw new ModelException("Feature: " + featureName + " not found in store: " + featureStore.getName());
      }
    }
    return extractedFromStore;
  }

  /**
   * Creates the normalizer described by the {@link #NORM_KEY} entry of the
   * feature map, or returns the identity normalizer if there is no such entry.
   */
  @SuppressWarnings("unchecked")
  private static Normalizer createNormalizerFromFeatureMap(SolrResourceLoader solrResourceLoader,
      Map<String,Object> featureMap) {
    final Map<String,Object> normMap = (Map<String,Object>)featureMap.get(NORM_KEY);
    return (normMap == null ? IdentityNormalizer.INSTANCE
        : fromNormalizerMap(solrResourceLoader, normMap));
  }

  /** Converts a feature and its normalizer into a map with a name and a norm entry. */
  private static LinkedHashMap<String,Object> toFeatureMap(Feature feature, Normalizer norm) {
    final LinkedHashMap<String,Object> map = new LinkedHashMap<String,Object>(2, 1.0f);
    map.put(NAME_KEY, feature.getName());
    map.put(NORM_KEY, toNormalizerMap(norm));
    return map;
  }

  /**
   * Instantiates a normalizer from the {@link #CLASS_KEY} and
   * {@link #PARAMS_KEY} entries of its map representation.
   */
  private static Normalizer fromNormalizerMap(SolrResourceLoader solrResourceLoader,
      Map<String,Object> normMap) {
    final String className = (String) normMap.get(CLASS_KEY);

    @SuppressWarnings("unchecked")
    final Map<String,Object> params = (Map<String,Object>) normMap.get(PARAMS_KEY);

    return Normalizer.getInstance(solrResourceLoader, className, params);
  }

  /**
   * Converts a normalizer into its map representation: always the class name,
   * plus the params when the normalizer reports any.
   */
  private static LinkedHashMap<String,Object> toNormalizerMap(Normalizer norm) {
    final LinkedHashMap<String,Object> normalizer = new LinkedHashMap<>(2, 1.0f);

    normalizer.put(CLASS_KEY, norm.getClass().getName());

    final LinkedHashMap<String,Object> params = norm.paramsToMap();
    if (params != null) {
      normalizer.put(PARAMS_KEY, params);
    }

    return normalizer;
  }
}