blob: ecaa3c3cff4f3c87782ce9e14e851c8719b1c525 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.state.api;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/** IT case for reading state. */
public abstract class SavepointReaderITTestBase extends AbstractTestBase {
    /** Uid assigned to the stateful operator so its state can be located in the savepoint. */
    static final String UID = "stateful-operator";

    /** Name of the plain (round-robin redistributed) operator list state. */
    static final String LIST_NAME = "list";

    /** Name of the union-redistributed operator list state. */
    static final String UNION_NAME = "union";

    /** Name of the broadcast state. */
    static final String BROADCAST_NAME = "broadcast";

    // State descriptors are supplied by the concrete subclass so the same test
    // logic can be exercised with different descriptor/serializer setups.
    private final ListStateDescriptor<Integer> list;

    private final ListStateDescriptor<Integer> union;

    private final MapStateDescriptor<Integer, String> broadcast;

    SavepointReaderITTestBase(
            ListStateDescriptor<Integer> list,
            ListStateDescriptor<Integer> union,
            MapStateDescriptor<Integer, String> broadcast) {
        this.list = list;
        this.union = union;
        this.broadcast = broadcast;
    }
@Test
public void testOperatorStateInputFormat() throws Exception {
StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment();
streamEnv.setParallelism(4);
DataStream<Integer> data = streamEnv.addSource(new SavepointSource()).rebalance();
data.connect(data.broadcast(broadcast))
.process(new StatefulOperator(list, union, broadcast))
.uid(UID)
.addSink(new DiscardingSink<>());
JobGraph jobGraph = streamEnv.getStreamGraph().getJobGraph();
String savepoint = takeSavepoint(jobGraph);
ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment();
verifyListState(savepoint, batchEnv);
verifyUnionState(savepoint, batchEnv);
verifyBroadcastState(savepoint, batchEnv);
}
    /** Reads the contents of the plain list state from the given savepoint. */
    abstract DataSet<Integer> readListState(ExistingSavepoint savepoint) throws IOException;

    /** Reads the contents of the union list state from the given savepoint. */
    abstract DataSet<Integer> readUnionState(ExistingSavepoint savepoint) throws IOException;

    /** Reads the broadcast state as (key, value) entries from the given savepoint. */
    abstract DataSet<Tuple2<Integer, String>> readBroadcastState(ExistingSavepoint savepoint)
            throws IOException;
private void verifyListState(String path, ExecutionEnvironment batchEnv) throws Exception {
ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
List<Integer> listResult = readListState(savepoint).collect();
listResult.sort(Comparator.naturalOrder());
Assert.assertEquals(
"Unexpected elements read from list state",
SavepointSource.getElements(),
listResult);
}
private void verifyUnionState(String path, ExecutionEnvironment batchEnv) throws Exception {
ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
List<Integer> unionResult = readUnionState(savepoint).collect();
unionResult.sort(Comparator.naturalOrder());
Assert.assertEquals(
"Unexpected elements read from union state",
SavepointSource.getElements(),
unionResult);
}
private void verifyBroadcastState(String path, ExecutionEnvironment batchEnv) throws Exception {
ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend());
List<Tuple2<Integer, String>> broadcastResult = readBroadcastState(savepoint).collect();
List<Integer> broadcastStateKeys =
broadcastResult.stream()
.map(entry -> entry.f0)
.sorted(Comparator.naturalOrder())
.collect(Collectors.toList());
List<String> broadcastStateValues =
broadcastResult.stream()
.map(entry -> entry.f1)
.sorted(Comparator.naturalOrder())
.collect(Collectors.toList());
Assert.assertEquals(
"Unexpected element in broadcast state keys",
SavepointSource.getElements(),
broadcastStateKeys);
Assert.assertEquals(
"Unexpected element in broadcast state values",
SavepointSource.getElements().stream()
.map(Object::toString)
.sorted()
.collect(Collectors.toList()),
broadcastStateValues);
}
private String takeSavepoint(JobGraph jobGraph) throws Exception {
SavepointSource.initializeForTest();
ClusterClient<?> client = miniClusterResource.getClusterClient();
JobID jobId = jobGraph.getJobID();
Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5));
String dirPath = getTempDirPath(new AbstractID().toHexString());
try {
JobID jobID = client.submitJob(jobGraph).get();
boolean finished = false;
while (deadline.hasTimeLeft()) {
if (SavepointSource.isFinished()) {
finished = true;
break;
}
try {
Thread.sleep(2L);
} catch (InterruptedException ignored) {
Thread.currentThread().interrupt();
}
}
if (!finished) {
Assert.fail("Failed to initialize state within deadline");
}
CompletableFuture<String> path = client.triggerSavepoint(jobID, dirPath);
return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
} finally {
client.cancel(jobId).get();
}
}
private static class SavepointSource implements SourceFunction<Integer> {
private static volatile boolean finished;
private volatile boolean running = true;
private static final Integer[] elements = {1, 2, 3};
@Override
public void run(SourceContext<Integer> ctx) {
synchronized (ctx.getCheckpointLock()) {
for (Integer element : elements) {
ctx.collect(element);
}
finished = true;
}
while (running) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
// ignore
}
}
}
@Override
public void cancel() {
running = false;
}
private static void initializeForTest() {
finished = false;
}
private static boolean isFinished() {
return finished;
}
private static List<Integer> getElements() {
return Arrays.asList(elements);
}
}
private static class StatefulOperator extends BroadcastProcessFunction<Integer, Integer, Void>
implements CheckpointedFunction {
private final ListStateDescriptor<Integer> list;
private final ListStateDescriptor<Integer> union;
private final MapStateDescriptor<Integer, String> broadcast;
private List<Integer> elements;
private ListState<Integer> listState;
private ListState<Integer> unionState;
private StatefulOperator(
ListStateDescriptor<Integer> list,
ListStateDescriptor<Integer> union,
MapStateDescriptor<Integer, String> broadcast) {
this.list = list;
this.union = union;
this.broadcast = broadcast;
}
@Override
public void open(Configuration parameters) {
elements = new ArrayList<>();
}
@Override
public void processElement(Integer value, ReadOnlyContext ctx, Collector<Void> out) {
elements.add(value);
}
@Override
public void processBroadcastElement(Integer value, Context ctx, Collector<Void> out)
throws Exception {
ctx.getBroadcastState(broadcast).put(value, value.toString());
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
listState.clear();
listState.addAll(elements);
unionState.clear();
unionState.addAll(elements);
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
listState = context.getOperatorStateStore().getListState(list);
unionState = context.getOperatorStateStore().getUnionListState(union);
}
}
}