blob: b31a6ba770dc1eeabf0ec5ad1e83a610410fb6b5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.store.FeatureStore;
import org.junit.Test;
import org.mockito.Mockito;
/**
 * Tests for {@link WrapperModel}.
 *
 * <p>Covers three areas:
 * <ul>
 *   <li>validation of the wrapper itself;</li>
 *   <li>validation performed when a wrapped model is supplied via {@code updateModel};</li>
 *   <li>a reflection-based check that every overridable {@link LTRScoringModel} method is either
 *       overridden and delegated, or deliberately left to the wrapper.</li>
 * </ul>
 */
public class TestWrapperModel extends TestRerankBase {

  /**
   * Minimal concrete {@link WrapperModel}; {@code fetchModelMap()} simply returns {@code null}.
   */
  private static class StubWrapperModel extends WrapperModel {

    private StubWrapperModel(String name) {
      this(name, Collections.emptyList(), Collections.emptyList());
    }

    private StubWrapperModel(String name, List<Feature> features, List<Normalizer> norms) {
      // The given features are passed both as the model features and as the "all features" list.
      super(name, features, norms, FeatureStore.DEFAULT_FEATURE_STORE_NAME, features, Collections.emptyMap());
    }

    @Override
    public Map<String, Object> fetchModelMap() throws ModelException {
      return null;
    }
  }

  /**
   * Creates a Mockito mock of {@link LTRScoringModel} whose {@code getFeatureStoreName()},
   * {@code getFeatures()} and {@code getNorms()} return the given values.
   */
  private static LTRScoringModel createMockWrappedModel(String featureStoreName,
      List<Feature> features, List<Normalizer> norms) {
    LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
    Mockito.doReturn(featureStoreName).when(wrappedModel).getFeatureStoreName();
    Mockito.doReturn(features).when(wrappedModel).getFeatures();
    Mockito.doReturn(norms).when(wrappedModel).getNorms();
    return wrappedModel;
  }

  /**
   * Verifies wrapper and wrapped-model validation.
   *
   * <p>A wrapper model may not declare its own features or norms. {@code updateModel} rejects a
   * wrapped model whose feature store name differs, which has no features, or whose feature and
   * norm counts differ.
   */
  @Test
  public void testValidate() throws Exception {
    WrapperModel wrapperModel = new StubWrapperModel("testModel");
    wrapperModel.validate();

    // wrapper model with features
    WrapperModel wrapperModelWithFeatures = new StubWrapperModel("testModel",
        Collections.singletonList(new ValueFeature("val", Collections.emptyMap())), Collections.emptyList());
    ModelException e = expectThrows(ModelException.class, wrapperModelWithFeatures::validate);
    assertEquals("features must be empty for the wrapper model testModel", e.getMessage());

    // wrapper model with norms
    WrapperModel wrapperModelWithNorms = new StubWrapperModel("testModel",
        Collections.emptyList(), Collections.singletonList(IdentityNormalizer.INSTANCE));
    e = expectThrows(ModelException.class, wrapperModelWithNorms::validate);
    assertEquals("norms must be empty for the wrapper model testModel", e.getMessage());

    assumeWorkingMockito();

    // update valid model
    {
      LTRScoringModel wrappedModel =
          createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
              Arrays.asList(
                  new ValueFeature("v1", Collections.emptyMap()),
                  new ValueFeature("v2", Collections.emptyMap())),
              Arrays.asList(
                  IdentityNormalizer.INSTANCE,
                  IdentityNormalizer.INSTANCE)
          );
      wrapperModel.updateModel(wrappedModel);
    }

    // update invalid model (feature store mismatch)
    {
      LTRScoringModel wrappedModel =
          createMockWrappedModel("wrappedFeatureStore",
              Arrays.asList(
                  new ValueFeature("v1", Collections.emptyMap()),
                  new ValueFeature("v2", Collections.emptyMap())),
              Arrays.asList(
                  IdentityNormalizer.INSTANCE,
                  IdentityNormalizer.INSTANCE)
          );
      e = expectThrows(ModelException.class, () -> wrapperModel.updateModel(wrappedModel));
      assertEquals("wrapper feature store name (_DEFAULT_) must match the wrapped feature store name (wrappedFeatureStore)", e.getMessage());
    }

    // update invalid model (no features)
    {
      LTRScoringModel wrappedModel =
          createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
              Collections.emptyList(),
              Arrays.asList(
                  IdentityNormalizer.INSTANCE,
                  IdentityNormalizer.INSTANCE)
          );
      e = expectThrows(ModelException.class, () -> wrapperModel.updateModel(wrappedModel));
      assertEquals("no features declared for model testModel", e.getMessage());
    }

    // update invalid model (no norms)
    {
      LTRScoringModel wrappedModel =
          createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
              Arrays.asList(
                  new ValueFeature("v1", Collections.emptyMap()),
                  new ValueFeature("v2", Collections.emptyMap())),
              Collections.emptyList()
          );
      e = expectThrows(ModelException.class, () -> wrapperModel.updateModel(wrappedModel));
      assertEquals("counted 2 features and 0 norms in model testModel", e.getMessage());
    }
  }

  /**
   * Ensures the reflection-based override check and the explicit delegation check cover the same
   * number of methods.
   *
   * <p>The matching count guards against a new {@link LTRScoringModel} method being added without
   * updating this test.
   */
  @Test
  public void testMethodOverridesAndDelegation() throws Exception {
    assumeWorkingMockito();
    final int overridableMethodCount = testOverriddenMethods();
    final int methodCount = testDelegateMethods();
    assertEquals("method count mismatch", overridableMethodCount, methodCount);
  }

  /**
   * Inspects every overridable method declared directly on {@link LTRScoringModel}.
   *
   * <p>Methods that describe the wrapper itself must NOT be overridden by {@link WrapperModel}.
   * All other methods must be overridden with the same return type.
   *
   * @return the number of overridable methods inspected
   */
  private int testOverriddenMethods() throws Exception {
    // Methods whose result belongs to the wrapper itself, so WrapperModel must not override them.
    final List<String> nonOverriddenMethodNames = Arrays.asList(
        "getName", // the wrapper model's name is its own name i.e. _not_ the name of the wrapped model
        "getFeatureStoreName", // wrapper and wrapped model feature store should match, so need not override
        "getParams" // the wrapper model's params are its own params i.e. _not_ the params of the wrapped model
    );
    int overridableMethodCount = 0;
    for (final Method superClassMethod : LTRScoringModel.class.getDeclaredMethods()) {
      final int modifiers = superClassMethod.getModifiers();
      // final, static and private methods cannot be overridden. Synthetic methods (bridge
      // methods, compiler-generated lambda bodies) are not part of the source-level API.
      if (Modifier.isFinal(modifiers)
          || Modifier.isStatic(modifiers)
          || Modifier.isPrivate(modifiers)
          || superClassMethod.isSynthetic()) {
        continue;
      }
      ++overridableMethodCount;
      if (nonOverriddenMethodNames.contains(superClassMethod.getName())) {
        try {
          final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
              superClassMethod.getName(),
              superClassMethod.getParameterTypes());
          fail(WrapperModel.class + " need not override\n'" + superClassMethod + "'"
              + " but it does override\n'" + subClassMethod + "'");
        } catch (NoSuchMethodException ignored) {
          // expected: the method is intentionally not overridden
        }
      } else {
        try {
          final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
              superClassMethod.getName(),
              superClassMethod.getParameterTypes());
          assertEquals("getReturnType() difference",
              superClassMethod.getReturnType(),
              subClassMethod.getReturnType());
        } catch (NoSuchMethodException e) {
          fail(WrapperModel.class + " needs to override '" + superClassMethod + "'");
        }
      }
    }
    return overridableMethodCount;
  }

  /**
   * Calls each {@link LTRScoringModel} method on a spied wrapper around a mocked wrapped model.
   *
   * <p>For each method it verifies whether the call is delegated to the wrapped model
   * (exactly once) or not delegated at all.
   *
   * @return the number of methods accounted for, which must equal the overridable method count
   */
  private int testDelegateMethods() throws Exception {
    int methodCount = 0;
    WrapperModel wrapperModel = Mockito.spy(new StubWrapperModel("testModel"));

    // ignore validate in this test case
    Mockito.doNothing().when(wrapperModel).validate();
    ++methodCount;

    LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
    wrapperModel.updateModel(wrappedModel);

    // cannot be stubbed or verified
    ++methodCount; // toString
    ++methodCount; // hashCode
    ++methodCount; // equals

    // getFeatureStoreName : not delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getFeatureStoreName();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.never()).getFeatureStoreName();

    // getName : not delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getName();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.never()).getName();

    // getParams : not delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getParams();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.never()).getParams();

    // getNorms : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getNorms();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).getNorms();

    // ramBytesUsed : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.ramBytesUsed();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).ramBytesUsed();

    // getFeatures : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getFeatures();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).getFeatures();

    // getAllFeatures : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getAllFeatures();
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).getAllFeatures();

    // score : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.score(null);
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).score(null);

    // normalizeFeaturesInPlace : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.normalizeFeaturesInPlace(null);
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).normalizeFeaturesInPlace(null);

    // getNormalizerExplanation : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.getNormalizerExplanation(null, 0);
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).getNormalizerExplanation(null, 0);

    // explain : delegate
    Mockito.reset(wrappedModel);
    wrapperModel.explain(null, 0, 0.0f, null);
    ++methodCount;
    Mockito.verify(wrappedModel, Mockito.times(1)).explain(null, 0, 0.0f, null);

    return methodCount;
  }
}