blob: ad9c65d5a2a5e4be5d435b205babf55e8adbd7a5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.typeserializerupgrade;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateBackendLoader;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.StateMigrationException;
import org.apache.flink.util.TestLogger;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Tests the state migration behaviour when the underlying POJO type changes and one tries to
 * recover from old state.
 *
 * <p>Each test compiles a first version of a {@code Pojo} class at runtime and runs a stateful map
 * operator that writes instances of it into keyed or operator state. It takes a snapshot of that
 * state, then restores the snapshot into a second operator whose user-code class loader sees a
 * different version of {@code Pojo}. The second operator checks the restored contents. Every test
 * runs once per state backend name returned by {@link #parameters()}.
 */
@RunWith(Parameterized.class)
public class PojoSerializerUpgradeTest extends TestLogger {

    @Parameterized.Parameters(name = "StateBackend: {0}")
    public static Collection<String> parameters() {
        return Arrays.asList(
                StateBackendLoader.MEMORY_STATE_BACKEND_NAME,
                StateBackendLoader.FS_STATE_BACKEND_NAME,
                StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME);
    }

    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();

    /** Simple (package-less) name of the POJO class that is compiled at runtime. */
    private static final String POJO_NAME = "Pojo";

    /** Baseline POJO with the fields {@code long a} and {@code String b}. */
    private static final String SOURCE_A =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "private String b; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    /** Same fields as {@link #SOURCE_A} in a different declaration order; should be recoverable. */
    private static final String SOURCE_B =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private String b; "
                    + "private long a; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    /** Field {@code a} changes type from {@code long} to {@code double}; should not be recoverable. */
    private static final String SOURCE_C =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private double a; "
                    + "private String b; "
                    + "public double getA() { return a;} "
                    + "public void setA(double value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    /**
     * Adds a field {@code double c} to {@link #SOURCE_A}. Expected to be recoverable: the added
     * field is not set by the test, and {@code equals} compares it against the new POJO's
     * unassigned {@code c}.
     */
    private static final String SOURCE_D =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "private String b; "
                    + "private double c; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "public double getC() { return c; } "
                    + "public void setC(double value) { c = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b) && c == other.c;} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b, c); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \", \" + c + \")\";}}";

    /** Removes field {@code b} from {@link #SOURCE_A}; expected to be recoverable. */
    private static final String SOURCE_E =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a;} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a); } "
                    + "@Override public String toString() {return \"(\" + a + \")\";}}";

    /** State backend under test, loaded from configuration for the current parameter. */
    private final StateBackend stateBackend;

    public PojoSerializerUpgradeTest(String backendType)
            throws IOException, DynamicCodeLoadingException {
        Configuration config = new Configuration();
        config.setString(StateBackendOptions.STATE_BACKEND, backendType);
        config.setString(
                CheckpointingOptions.CHECKPOINTS_DIRECTORY,
                temporaryFolder.newFolder().toURI().toString());

        stateBackend =
                StateBackendLoader.loadStateBackendFromConfig(
                        config, Thread.currentThread().getContextClassLoader(), null);
    }

    /** We should be able to handle a changed field order of a POJO as keyed state. */
    @Test
    public void testChangedFieldOrderWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, true);
    }

    /** We should be able to handle a changed field order of a POJO as operator state. */
    @Test
    public void testChangedFieldOrderWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, false);
    }

    /** Changing field types of a POJO as keyed state should require a state migration. */
    @Test
    public void testChangedFieldTypesWithKeyedState() throws Exception {
        assertUpgradeRequiresStateMigration(SOURCE_A, SOURCE_C, true, true);
    }

    /** Changing field types of a POJO as operator state should require a state migration. */
    @Test
    public void testChangedFieldTypesWithOperatorState() throws Exception {
        assertUpgradeRequiresStateMigration(SOURCE_A, SOURCE_C, true, false);
    }

    /** Adding fields to a POJO as keyed state should succeed. */
    @Test
    public void testAdditionalFieldWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, true);
    }

    /** Adding fields to a POJO as operator state should succeed. */
    @Test
    public void testAdditionalFieldWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, false);
    }

    /** Removing fields from a POJO as keyed state should succeed. */
    @Test
    public void testMissingFieldWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, true);
    }

    /** Removing fields from a POJO as operator state should succeed. */
    @Test
    public void testMissingFieldWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, false);
    }

    /**
     * Runs {@link #testPojoSerializerUpgrade} and asserts that it fails with an exception whose
     * cause chain contains a {@link StateMigrationException}. Any other exception is rethrown. A
     * successful upgrade fails the test.
     */
    private void assertUpgradeRequiresStateMigration(
            String classSourceA, String classSourceB, boolean hasBField, boolean isKeyedState)
            throws Exception {
        try {
            testPojoSerializerUpgrade(classSourceA, classSourceB, hasBField, isKeyedState);
        } catch (Exception e) {
            if (CommonTestUtils.containsCause(e, StateMigrationException.class)) {
                // StateMigrationException expected
                return;
            }
            throw e;
        }
        fail("Expected a state migration exception.");
    }

    /**
     * Writes POJO state using the class compiled from {@code classSourceA}, snapshots it, and
     * restores and verifies it using the class compiled from {@code classSourceB}.
     *
     * @param classSourceA source of the {@code Pojo} class used to write the state
     * @param classSourceB source of the {@code Pojo} class used to restore the state
     * @param hasBField whether the mapper should set/compare field {@code b}; must be false when
     *     either version lacks that field
     * @param isKeyedState true to use keyed state, false to use operator (list/union) state
     */
    private void testPojoSerializerUpgrade(
            String classSourceA, String classSourceB, boolean hasBField, boolean isKeyedState)
            throws Exception {
        final Configuration taskConfiguration = new Configuration();
        final ExecutionConfig executionConfig = new ExecutionConfig();
        final KeySelector<Long, Long> keySelector = new IdentityKeySelector<>();
        final Collection<Long> inputs = Arrays.asList(1L, 2L, 45L, 67L, 1337L);

        // Run the program with classSourceA. Its class loader is kept open until the restore below
        // has finished and is closed afterwards so the loader resources are not leaked.
        try (URLClassLoader classLoaderA = compileAndLoadPojo(classSourceA)) {
            final OperatorSubtaskState stateHandles =
                    runOperator(
                            taskConfiguration,
                            executionConfig,
                            new StreamMap<>(new StatefulMapper(isKeyedState, false, hasBField)),
                            keySelector,
                            isKeyedState,
                            stateBackend,
                            classLoaderA,
                            null,
                            inputs);

            // Run the program with classSourceB, restoring and verifying the snapshot from above.
            try (URLClassLoader classLoaderB = compileAndLoadPojo(classSourceB)) {
                runOperator(
                        taskConfiguration,
                        executionConfig,
                        new StreamMap<>(new StatefulMapper(isKeyedState, true, hasBField)),
                        keySelector,
                        isKeyedState,
                        stateBackend,
                        classLoaderB,
                        stateHandles,
                        inputs);
            }
        }
    }

    /**
     * Sets up a test harness (keyed or non-keyed) around {@code operator}, optionally restores
     * {@code operatorSubtaskState}, feeds {@code input} with increasing timestamps and returns a
     * snapshot taken afterwards. The harness and the mock environment are always closed.
     *
     * @param operatorSubtaskState state to restore from, or {@code null} to start empty
     * @return the snapshot taken after all input has been processed
     */
    private OperatorSubtaskState runOperator(
            Configuration taskConfiguration,
            ExecutionConfig executionConfig,
            OneInputStreamOperator<Long, Long> operator,
            KeySelector<Long, Long> keySelector,
            boolean isKeyedState,
            StateBackend stateBackend,
            ClassLoader classLoader,
            OperatorSubtaskState operatorSubtaskState,
            Iterable<Long> input)
            throws Exception {

        try (final MockEnvironment environment =
                new MockEnvironmentBuilder()
                        .setTaskName("test task")
                        .setManagedMemorySize(32 * 1024)
                        .setInputSplitProvider(new MockInputSplitProvider())
                        .setBufferSize(256)
                        .setTaskConfiguration(taskConfiguration)
                        .setExecutionConfig(executionConfig)
                        .setMaxParallelism(16)
                        .setUserCodeClassLoader(classLoader)
                        .build()) {

            OneInputStreamOperatorTestHarness<Long, Long> harness = null;
            try {
                if (isKeyedState) {
                    harness =
                            new KeyedOneInputStreamOperatorTestHarness<>(
                                    operator,
                                    keySelector,
                                    BasicTypeInfo.LONG_TYPE_INFO,
                                    environment);
                } else {
                    harness =
                            new OneInputStreamOperatorTestHarness<>(
                                    operator, LongSerializer.INSTANCE, environment);
                }

                harness.setStateBackend(stateBackend);
                harness.setup();
                harness.initializeState(operatorSubtaskState);
                harness.open();

                long timestamp = 0L;
                for (Long value : input) {
                    harness.processElement(value, timestamp++);
                }

                long checkpointId = 1L;
                long checkpointTimestamp = timestamp + 1L;

                return harness.snapshot(checkpointId, checkpointTimestamp);
            } finally {
                IOUtils.closeQuietly(harness);
            }
        }
    }

    /**
     * Writes {@code source} as {@code Pojo.java} into a fresh temporary directory, compiles it and
     * returns a class loader for that directory whose parent is the current context class loader.
     * The caller is responsible for closing the returned loader.
     */
    private static URLClassLoader compileAndLoadPojo(String source) throws IOException {
        final File rootPath = temporaryFolder.newFolder();
        final File sourceFile = writeSourceFile(rootPath, POJO_NAME + ".java", source);
        compileClass(sourceFile);

        return URLClassLoader.newInstance(
                new URL[] {rootPath.toURI().toURL()},
                Thread.currentThread().getContextClassLoader());
    }

    /** Writes {@code source} to the file {@code root/name}, creating parent directories. */
    private static File writeSourceFile(File root, String name, String source) throws IOException {
        File sourceFile = new File(root, name);

        sourceFile.getParentFile().mkdirs();

        try (FileWriter writer = new FileWriter(sourceFile)) {
            writer.write(source);
        }

        return sourceFile;
    }

    /**
     * Compiles {@code sourceFile} in place with the system Java compiler (annotation processing
     * disabled).
     *
     * @throws IllegalStateException if no system compiler is available (e.g. running on a JRE) or
     *     if compilation fails, so the test fails here instead of later with a confusing
     *     class-loading error
     */
    private static void compileClass(File sourceFile) {
        final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            throw new IllegalStateException(
                    "No system Java compiler available; this test must run on a JDK.");
        }

        final int exitCode = compiler.run(null, null, null, "-proc:none", sourceFile.getPath());
        if (exitCode != 0) {
            throw new IllegalStateException(
                    "Compilation of " + sourceFile + " failed with exit code " + exitCode + '.');
        }
    }

    /**
     * Map function that either writes a {@code Pojo} instance per input into state
     * ({@code verify == false}) or checks that the restored state contains an equal instance
     * ({@code verify == true}). The POJO class is loaded reflectively from the user-code class
     * loader so that each run can see a different version of it.
     */
    private static final class StatefulMapper extends RichMapFunction<Long, Long>
            implements CheckpointedFunction {

        private static final long serialVersionUID = -520490739059396832L;

        private final boolean keyed;
        private final boolean verify;
        private final boolean hasBField;

        // keyed states
        private transient ValueState<Object> keyedValueState;
        private transient ListState<Object> keyedListState;
        private transient ReducingState<Object> keyedReducingState;

        // operator states
        private transient ListState<Object> partitionableListState;
        private transient ListState<Object> unionListState;

        private transient Class<?> pojoClass;
        private transient Field fieldA;
        private transient Field fieldB;

        StatefulMapper(boolean keyed, boolean verify, boolean hasBField) {
            this.keyed = keyed;
            this.verify = verify;
            this.hasBField = hasBField;
        }

        @Override
        public Long map(Long value) throws Exception {
            final Object pojo = pojoClass.getDeclaredConstructor().newInstance();

            fieldA.set(pojo, value);

            if (hasBField) {
                fieldB.set(pojo, String.valueOf(value));
            }

            if (verify) {
                if (keyed) {
                    assertEquals(pojo, keyedValueState.value());
                    assertTrue(containsElement(keyedListState.get(), pojo));
                    assertEquals(pojo, keyedReducingState.get());
                } else {
                    assertTrue(containsElement(partitionableListState.get(), pojo));
                    assertTrue(containsElement(unionListState.get(), pojo));
                }
            } else {
                if (keyed) {
                    keyedValueState.update(pojo);
                    keyedListState.add(pojo);
                    keyedReducingState.add(pojo);
                } else {
                    partitionableListState.add(pojo);
                    unionListState.add(pojo);
                }
            }

            return value;
        }

        /** Returns whether any element of {@code elements} is equal to {@code expected}. */
        private static boolean containsElement(Iterable<Object> elements, Object expected) {
            final Iterator<Object> iterator = elements.iterator();
            while (iterator.hasNext()) {
                if (expected.equals(iterator.next())) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // Nothing to do: all state is written directly in map().
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            pojoClass = getRuntimeContext().getUserCodeClassLoader().loadClass(POJO_NAME);

            fieldA = pojoClass.getDeclaredField("a");
            fieldA.setAccessible(true);

            if (hasBField) {
                fieldB = pojoClass.getDeclaredField("b");
                fieldB.setAccessible(true);
            }

            // The POJO type is only known at runtime, so state descriptors are typed as Object.
            @SuppressWarnings("unchecked")
            final Class<Object> stateType = (Class<Object>) pojoClass;

            if (keyed) {
                keyedValueState =
                        context.getKeyedStateStore()
                                .getState(new ValueStateDescriptor<>("keyedValueState", stateType));
                keyedListState =
                        context.getKeyedStateStore()
                                .getListState(
                                        new ListStateDescriptor<>("keyedListState", stateType));

                ReduceFunction<Object> reduceFunction = new FirstValueReducer<>();
                keyedReducingState =
                        context.getKeyedStateStore()
                                .getReducingState(
                                        new ReducingStateDescriptor<>(
                                                "keyedReducingState", reduceFunction, stateType));
            } else {
                partitionableListState =
                        context.getOperatorStateStore()
                                .getListState(
                                        new ListStateDescriptor<>(
                                                "partitionableListState", stateType));
                unionListState =
                        context.getOperatorStateStore()
                                .getUnionListState(
                                        new ListStateDescriptor<>("unionListState", stateType));
            }
        }
    }

    /** Reducer that always keeps the first value it has seen. */
    private static final class FirstValueReducer<T> implements ReduceFunction<T> {

        private static final long serialVersionUID = -9222976423336835926L;

        @Override
        public T reduce(T value1, T value2) throws Exception {
            return value1;
        }
    }

    /** Key selector that uses the element itself as its key. */
    private static final class IdentityKeySelector<T> implements KeySelector<T, T> {

        private static final long serialVersionUID = -3263628393881929147L;

        @Override
        public T getKey(T value) throws Exception {
            return value;
        }
    }
}