blob: 1622a3542bb27be9388ef01ad0ef599fe52799cf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.api;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.BooleanParam;
import org.apache.flink.ml.param.DoubleArrayArrayParam;
import org.apache.flink.ml.param.DoubleArrayParam;
import org.apache.flink.ml.param.DoubleParam;
import org.apache.flink.ml.param.FloatArrayParam;
import org.apache.flink.ml.param.FloatParam;
import org.apache.flink.ml.param.IntArrayParam;
import org.apache.flink.ml.param.IntParam;
import org.apache.flink.ml.param.LongArrayParam;
import org.apache.flink.ml.param.LongParam;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidator;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.param.StringArrayArrayParam;
import org.apache.flink.ml.param.StringArrayParam;
import org.apache.flink.ml.param.StringParam;
import org.apache.flink.ml.param.VectorParam;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/** Tests the behavior of Stage and WithParams. */
public class StageTest {
// A WithParams subclass which has one parameter for each pre-defined parameter type.
private interface MyParams<T> extends WithParams<T> {
Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false);
Param<Integer> INT_PARAM =
new IntParam("intParam", "Description", 1, ParamValidators.lt(100));
// This param might be regarded as integer type given a small value.
Param<Long> LONG_PARAM =
new LongParam("longParam", "Description", 2L, ParamValidators.lt(100));
// This param must be regarded as long type given a large enough value.
Param<Long> LONG_PARAM2 =
new LongParam(
"longParam2",
"Description",
Integer.MAX_VALUE + 1L,
ParamValidators.lt(Integer.MAX_VALUE + 100L));
Param<Float> FLOAT_PARAM =
new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100));
Param<Double> DOUBLE_PARAM =
new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100));
Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5");
Param<Integer[]> INT_ARRAY_PARAM =
new IntArrayParam("intArrayParam", "Description", new Integer[] {6, 7});
Param<Long[]> LONG_ARRAY_PARAM =
new LongArrayParam(
"longArrayParam",
"Description",
new Long[] {8L, 9L},
ParamValidators.alwaysTrue());
Param<Float[]> FLOAT_ARRAY_PARAM =
new FloatArrayParam("floatArrayParam", "Description", new Float[] {10.0f, 11.0f});
Param<Double[]> DOUBLE_ARRAY_PARAM =
new DoubleArrayParam(
"doubleArrayParam",
"Description",
new Double[] {12.0, 13.0},
ParamValidators.alwaysTrue());
Param<String[]> STRING_ARRAY_PARAM =
new StringArrayParam("stringArrayParam", "Description", new String[] {"14", "15"});
Param<String[][]> STRING_ARRAY_ARRAY_PARAM =
new StringArrayArrayParam(
"stringArrayArrayParam",
"Description",
new String[][] {new String[] {"14", "15"}});
Param<Double[][]> DOUBLE_ARRAY_ARRAY_PARAM =
new DoubleArrayArrayParam(
"doubleArrayArrayParam",
"Description",
new Double[][] {new Double[] {14.0, 15.0}, new Double[] {16.0, 17.0}});
Param<Vector> VECTOR_PARAM =
new VectorParam("vectorParam", "Description", Vectors.dense(1.0, 2.0, 3.0));
}
/**
* A Stage subclass which inherits all parameters from MyParams and defines an extra parameter.
*/
public static class MyStage implements Stage<MyStage>, MyParams<MyStage> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
public final Param<Integer> extraIntParam =
new IntParam("extraIntParam", "Description", 20, ParamValidators.alwaysTrue());
public final Param<Integer> paramWithNullDefault =
new IntParam(
"paramWithNullDefault",
"Must be explicitly set with a non-null value",
null,
ParamValidators.notNull());
public MyStage() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
public Map<Param<?>, Object> getParamMap() {
return paramMap;
}
@Override
public void save(String path) throws IOException {
ReadWriteUtils.saveMetadata(this, path);
}
public static MyStage load(StreamTableEnvironment tEnv, String path) throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
}
/** A Stage subclass without the static load() method. */
public static class MyStageWithoutLoad implements Stage<MyStage>, MyParams<MyStage> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
public MyStageWithoutLoad() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
public void save(String path) throws IOException {
ReadWriteUtils.saveMetadata(this, path);
}
@Override
public Map<Param<?>, Object> getParamMap() {
return paramMap;
}
}
// Asserts that m1 and m2 are equivalent.
private static void assertParamMapEquals(Map<Param<?>, Object> m1, Map<Param<?>, Object> m2) {
Assert.assertTrue(m1 != null && m2 != null);
Assert.assertEquals(m1.size(), m2.size());
for (Map.Entry<Param<?>, Object> entry : m1.entrySet()) {
Assert.assertTrue(m2.containsKey(entry.getKey()));
Object v1 = entry.getValue();
Object v2 = m2.get(entry.getKey());
if (v1 == null || v2 == null) {
Assert.assertTrue(v1 == null && v2 == null);
} else if (v1.getClass().isArray() && v2.getClass().isArray()) {
Assert.assertArrayEquals((Object[]) v1, (Object[]) v2);
} else {
Assert.assertEquals(v1, v2);
}
}
}
// Saves and loads the given stage. And verifies that the loaded stage has same parameter values
// as the original stage.
private static Stage<?> validateStageSaveLoad(
StreamTableEnvironment tEnv, Stage<?> stage, Map<String, Object> paramOverrides)
throws IOException {
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = stage.getParam(entry.getKey());
ReadWriteUtils.setParam(stage, param, entry.getValue());
}
String path = Files.createTempDirectory("").toString();
stage.save(path);
try {
stage.save(path);
Assert.fail("Expected IOException");
} catch (IOException e) {
// This is expected.
}
Stage<?> loadedStage = ReadWriteUtils.loadStage(tEnv, path);
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = loadedStage.getParam(entry.getKey());
Assert.assertEquals(entry.getValue(), loadedStage.get(param));
}
assertParamMapEquals(stage.getParamMap(), loadedStage.getParamMap());
return loadedStage;
}
@Test
public void testParamSetValueWithName() {
MyStage stage = new MyStage();
Param<Integer> paramA = MyParams.INT_PARAM;
stage.set(paramA, 2);
Assert.assertEquals(2, (int) stage.get(paramA));
Param<Integer> paramB = stage.getParam("intParam");
stage.set(paramB, 3);
Assert.assertEquals(3, (int) stage.get(paramB));
Param<Integer> paramC = stage.getParam("extraIntParam");
stage.set(paramC, 50);
Assert.assertEquals(50, (int) stage.get(paramC));
}
@Test
public void testParamWithNullDefault() {
MyStage stage = new MyStage();
try {
stage.get(stage.paramWithNullDefault);
Assert.fail("Expected IllegalArgumentException");
} catch (IllegalArgumentException e) {
Assert.assertTrue(e.getMessage().contains("should not be null"));
}
stage.set(stage.paramWithNullDefault, 3);
Assert.assertEquals(3, (int) stage.get(stage.paramWithNullDefault));
}
private static <T> void assertInvalidValue(Stage<?> stage, Param<T> param, T value) {
try {
stage.set(param, value);
Assert.fail("Expected IllegalArgumentException");
} catch (IllegalArgumentException e) {
Assert.assertTrue(e.getMessage().contains("invalid value"));
}
}
private static <T> void assertInvalidClass(Stage<?> stage, Param<T> param, Object value) {
try {
stage.set(param, (T) value);
Assert.fail("Expected ClassCastException");
} catch (ClassCastException e) {
Assert.assertTrue(e.getMessage().contains("incompatible class"));
}
}
@Test
public void testSetUndefinedParam() {
MyStage stage = new MyStage();
Param<Integer> param = new IntParam("anotherIntParam", "Not defined on MyStage", 1);
try {
stage.set(param, 2);
Assert.fail("Expected IllegalArgumentException");
} catch (IllegalArgumentException e) {
Assert.assertTrue(e.getMessage().contains(MyStage.class.getName()));
}
}
@Test
public void testParamSetInvalidValue() {
MyStage stage = new MyStage();
assertInvalidValue(stage, MyParams.INT_PARAM, 100);
assertInvalidValue(stage, MyParams.LONG_PARAM, 100L);
assertInvalidValue(stage, MyParams.LONG_PARAM2, Integer.MAX_VALUE + 100L);
assertInvalidValue(stage, MyParams.FLOAT_PARAM, 100.0f);
assertInvalidValue(stage, MyParams.DOUBLE_PARAM, 100.0);
assertInvalidClass(stage, MyParams.INT_PARAM, "100");
assertInvalidClass(stage, MyParams.STRING_PARAM, 100);
Param<Integer> param = stage.getParam("stringParam");
assertInvalidClass(stage, param, 50);
}
@Test
public void testParamSetValidValue() {
MyStage stage = new MyStage();
stage.set(MyParams.BOOLEAN_PARAM, true);
Assert.assertEquals(true, stage.get(MyParams.BOOLEAN_PARAM));
stage.set(MyParams.INT_PARAM, 50);
Assert.assertEquals(50, (int) stage.get(MyParams.INT_PARAM));
stage.set(MyParams.LONG_PARAM, 50L);
Assert.assertEquals(50L, (long) stage.get(MyParams.LONG_PARAM));
stage.set(MyParams.LONG_PARAM2, Integer.MAX_VALUE + 50L);
Assert.assertEquals(Integer.MAX_VALUE + 50L, (long) stage.get(MyParams.LONG_PARAM2));
stage.set(MyParams.FLOAT_PARAM, 50f);
Assert.assertEquals(50f, stage.get(MyParams.FLOAT_PARAM), 0.0001);
stage.set(MyParams.DOUBLE_PARAM, 50.0);
Assert.assertEquals(50, stage.get(MyParams.DOUBLE_PARAM), 0.0001);
stage.set(MyParams.STRING_PARAM, "50");
Assert.assertEquals("50", stage.get(MyParams.STRING_PARAM));
stage.set(MyParams.INT_ARRAY_PARAM, new Integer[] {50, 51});
Assert.assertArrayEquals(new Integer[] {50, 51}, stage.get(MyParams.INT_ARRAY_PARAM));
stage.set(MyParams.LONG_ARRAY_PARAM, new Long[] {50L, 51L});
Assert.assertArrayEquals(new Long[] {50L, 51L}, stage.get(MyParams.LONG_ARRAY_PARAM));
stage.set(MyParams.FLOAT_ARRAY_PARAM, new Float[] {50.0f, 51.0f});
Assert.assertArrayEquals(new Float[] {50.0f, 51.0f}, stage.get(MyParams.FLOAT_ARRAY_PARAM));
stage.set(MyParams.DOUBLE_ARRAY_PARAM, new Double[] {50.0, 51.0});
Assert.assertArrayEquals(new Double[] {50.0, 51.0}, stage.get(MyParams.DOUBLE_ARRAY_PARAM));
stage.set(MyParams.STRING_ARRAY_PARAM, new String[] {"50", "51"});
Assert.assertArrayEquals(new String[] {"50", "51"}, stage.get(MyParams.STRING_ARRAY_PARAM));
stage.set(MyParams.VECTOR_PARAM, Vectors.dense(1, 5, 3));
Assert.assertEquals(Vectors.dense(1, 5, 3), stage.get(MyParams.VECTOR_PARAM));
stage.set(
MyParams.DOUBLE_ARRAY_ARRAY_PARAM,
new Double[][] {new Double[] {50.0, 51.0}, new Double[] {52.0, 53.0}});
Assert.assertEquals(2, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM).length);
Assert.assertArrayEquals(
new Double[] {50.0, 51.0}, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM)[0]);
Assert.assertArrayEquals(
new Double[] {52.0, 53.0}, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM)[1]);
stage.set(
MyParams.STRING_ARRAY_ARRAY_PARAM,
new String[][] {
new String[] {"50", "51"},
new String[] {"52", "53"}
});
Assert.assertEquals(2, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM).length);
Assert.assertArrayEquals(
new String[] {"50", "51"}, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM)[0]);
Assert.assertArrayEquals(
new String[] {"52", "53"}, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM)[1]);
}
@Test
public void testStageSaveLoad() throws IOException {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
stage.set(MyParams.BOOLEAN_PARAM, true);
stage.set(MyParams.INT_PARAM, 50);
stage.set(MyParams.LONG_PARAM, 50L);
stage.set(MyParams.LONG_PARAM2, Integer.MAX_VALUE + 50L);
stage.set(MyParams.FLOAT_PARAM, 50f);
stage.set(MyParams.DOUBLE_PARAM, 50.0);
stage.set(MyParams.STRING_PARAM, "50");
stage.set(MyParams.INT_ARRAY_PARAM, new Integer[] {50, 51});
stage.set(MyParams.LONG_ARRAY_PARAM, new Long[] {50L, 51L});
stage.set(MyParams.FLOAT_ARRAY_PARAM, new Float[] {50.0f, 51.0f});
stage.set(MyParams.DOUBLE_ARRAY_PARAM, new Double[] {50.0, 51.0});
stage.set(MyParams.STRING_ARRAY_PARAM, new String[] {"50", "51"});
stage.set(MyParams.VECTOR_PARAM, Vectors.dense(2, 3, 4));
stage.set(
MyParams.DOUBLE_ARRAY_ARRAY_PARAM,
new Double[][] {new Double[] {50.0, 51.0}, new Double[] {52.0, 53.0}});
stage.set(
MyParams.STRING_ARRAY_ARRAY_PARAM,
new String[][] {
new String[] {"50", "51"},
new String[] {"52", "53"}
});
Stage<?> loadedStage = validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
Assert.assertEquals(1, (int) loadedStage.get(stage.paramWithNullDefault));
Assert.assertEquals(true, loadedStage.get(MyParams.BOOLEAN_PARAM));
Assert.assertEquals(50, (int) loadedStage.get(MyParams.INT_PARAM));
Assert.assertEquals(50L, (long) loadedStage.get(MyParams.LONG_PARAM));
Assert.assertEquals(Integer.MAX_VALUE + 50L, (long) loadedStage.get(MyParams.LONG_PARAM2));
Assert.assertEquals(50f, loadedStage.get(MyParams.FLOAT_PARAM), 0.0001);
Assert.assertEquals(50, loadedStage.get(MyParams.DOUBLE_PARAM), 0.0001);
Assert.assertEquals("50", loadedStage.get(MyParams.STRING_PARAM));
Assert.assertArrayEquals(new Integer[] {50, 51}, loadedStage.get(MyParams.INT_ARRAY_PARAM));
Assert.assertArrayEquals(new Long[] {50L, 51L}, loadedStage.get(MyParams.LONG_ARRAY_PARAM));
Assert.assertArrayEquals(
new Float[] {50.0f, 51.0f}, loadedStage.get(MyParams.FLOAT_ARRAY_PARAM));
Assert.assertArrayEquals(
new Double[] {50.0, 51.0}, loadedStage.get(MyParams.DOUBLE_ARRAY_PARAM));
Assert.assertArrayEquals(
new String[] {"50", "51"}, loadedStage.get(MyParams.STRING_ARRAY_PARAM));
Assert.assertEquals(2, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM).length);
Assert.assertArrayEquals(
new Double[] {50.0, 51.0}, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM)[0]);
Assert.assertArrayEquals(
new Double[] {52.0, 53.0}, stage.get(MyParams.DOUBLE_ARRAY_ARRAY_PARAM)[1]);
Assert.assertEquals(2, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM).length);
Assert.assertArrayEquals(
new String[] {"50", "51"}, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM)[0]);
Assert.assertArrayEquals(
new String[] {"52", "53"}, stage.get(MyParams.STRING_ARRAY_ARRAY_PARAM)[1]);
Assert.assertEquals(Vectors.dense(2, 3, 4), loadedStage.get(MyParams.VECTOR_PARAM));
}
@Test
public void testStageSaveLoadWithParamOverrides() throws IOException {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
Stage<?> loadedStage =
validateStageSaveLoad(
tEnv, stage, Collections.singletonMap("paramWithNullDefault", 10));
Assert.assertEquals(10, (int) loadedStage.get(stage.paramWithNullDefault));
}
@Test
public void testStageLoadWithoutLoadMethod() throws IOException {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStageWithoutLoad stage = new MyStageWithoutLoad();
try {
validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
Assert.fail("Expected RuntimeException");
} catch (RuntimeException e) {
Assert.assertTrue(e.getMessage().contains("not implemented"));
}
}
@Test
public void testValidators() {
ParamValidator<Integer> gt = ParamValidators.gt(10);
Assert.assertFalse(gt.validate(null));
Assert.assertFalse(gt.validate(5));
Assert.assertFalse(gt.validate(10));
Assert.assertTrue(gt.validate(15));
ParamValidator<Integer> gtEq = ParamValidators.gtEq(10);
Assert.assertFalse(gtEq.validate(null));
Assert.assertFalse(gtEq.validate(5));
Assert.assertTrue(gtEq.validate(10));
Assert.assertTrue(gtEq.validate(15));
ParamValidator<Integer> lt = ParamValidators.lt(10);
Assert.assertFalse(lt.validate(null));
Assert.assertTrue(lt.validate(5));
Assert.assertFalse(lt.validate(10));
Assert.assertFalse(lt.validate(15));
ParamValidator<Integer> ltEq = ParamValidators.ltEq(10);
Assert.assertFalse(ltEq.validate(null));
Assert.assertTrue(ltEq.validate(5));
Assert.assertTrue(ltEq.validate(10));
Assert.assertFalse(ltEq.validate(15));
ParamValidator<Integer> inRangeInclusive = ParamValidators.inRange(5, 15);
Assert.assertFalse(inRangeInclusive.validate(null));
Assert.assertFalse(inRangeInclusive.validate(0));
Assert.assertTrue(inRangeInclusive.validate(5));
Assert.assertTrue(inRangeInclusive.validate(10));
Assert.assertTrue(inRangeInclusive.validate(15));
Assert.assertFalse(inRangeInclusive.validate(20));
ParamValidator<Integer> inRangeExclusive = ParamValidators.inRange(5, 15, false, false);
Assert.assertFalse(inRangeExclusive.validate(null));
Assert.assertFalse(inRangeExclusive.validate(0));
Assert.assertFalse(inRangeExclusive.validate(5));
Assert.assertTrue(inRangeExclusive.validate(10));
Assert.assertFalse(inRangeExclusive.validate(15));
Assert.assertFalse(inRangeExclusive.validate(20));
ParamValidator<Integer> inArray = ParamValidators.inArray(1, 2, 3);
Assert.assertFalse(inArray.validate(null));
Assert.assertTrue(inArray.validate(1));
Assert.assertFalse(inArray.validate(0));
ParamValidator<Integer> notNull = ParamValidators.notNull();
Assert.assertTrue(notNull.validate(5));
Assert.assertFalse(notNull.validate(null));
ParamValidator<Object[]> nonEmptyArray = ParamValidators.nonEmptyArray();
Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"}));
Assert.assertFalse(nonEmptyArray.validate(null));
Assert.assertFalse(nonEmptyArray.validate(new String[0]));
ParamValidator<String[]> isSubArray = ParamValidators.isSubSet("a", "b", "c");
Assert.assertFalse(isSubArray.validate(null));
Assert.assertFalse(isSubArray.validate(new String[] {"c", "v"}));
Assert.assertTrue(isSubArray.validate(new String[] {"a", "b"}));
Assert.assertFalse(isSubArray.validate(new String[] {"e", "v"}));
}
}