blob: 28e9e1c3f2f119f27a634744f2d7e9b466789120 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.state.api;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
import org.apache.flink.state.api.functions.StateBootstrapFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.util.StreamCollector;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.apache.flink.util.SerializedThrowable;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
/** Integration tests for writing and modifying savepoints with {@link SavepointWriter}. */
public class SavepointWriterITCase extends AbstractTestBase {
/** Operator uid under which the keyed account state is written and later restored. */
private static final String ACCOUNT_UID = "accounts";
/** Operator uid under which the broadcast currency-rate state is written and later restored. */
private static final String CURRENCY_UID = "currency";
/** Operator uid of the operator that is added when the bootstrapped savepoint is modified. */
private static final String MODIFY_UID = "numbers";
/** Descriptor of the broadcast state that maps a currency name to its rate. */
private static final MapStateDescriptor<String, Double> descriptor =
new MapStateDescriptor<>("currency-rate", Types.STRING, Types.DOUBLE);
/** Input accounts used both to bootstrap the keyed state and to drive the validation jobs. */
private static final Collection<Account> accounts =
Arrays.asList(new Account(1, 100.0), new Account(2, 100.0), new Account(3, 100.0));
/** Input currency rates used both to bootstrap and to validate the broadcast state. */
private static final Collection<CurrencyRate> currencyRates =
Arrays.asList(new CurrencyRate("USD", 1.0), new CurrencyRate("EUR", 1.3));
/** JUnit rule used to collect the output of a stream into a future. */
@Rule public StreamCollector collector = new StreamCollector();
/** Runs the bootstrap/modify round trip without configuring an explicit state backend. */
@Test
public void testDefaultStateBackend() throws Exception {
    // A null backend makes the helpers use the writer overloads without a backend argument.
    final StateBackend noExplicitBackend = null;
    testStateBootstrapAndModification(noExplicitBackend);
}
/** Runs the bootstrap/modify round trip with a {@link HashMapStateBackend}. */
@Test
public void testHashMapStateBackend() throws Exception {
    final StateBackend hashMapBackend = new HashMapStateBackend();
    testStateBootstrapAndModification(hashMapBackend);
}
/** Runs the bootstrap/modify round trip with an {@link EmbeddedRocksDBStateBackend}. */
@Test
public void testEmbeddedRocksDBStateBackend() throws Exception {
    final StateBackend rocksDbBackend = new EmbeddedRocksDBStateBackend();
    testStateBootstrapAndModification(rocksDbBackend);
}
/**
 * Bootstraps a savepoint, validates it, then derives a modified savepoint from it and validates
 * that one as well.
 *
 * @param backend state backend to use for every step, or {@code null} to not configure one
 * @throws Exception if any of the write or validation steps fails
 */
public void testStateBootstrapAndModification(StateBackend backend) throws Exception {
    // Step 1: write a fresh savepoint and check that a job can restore from it.
    final String bootstrappedSavepoint = getTempDirPath(new AbstractID().toHexString());
    bootstrapState(backend, bootstrappedSavepoint);
    validateBootstrap(backend, bootstrappedSavepoint);

    // Step 2: rewrite that savepoint into a second location and check the result.
    final String modifiedSavepoint = getTempDirPath(new AbstractID().toHexString());
    modifySavepoint(backend, bootstrappedSavepoint, modifiedSavepoint);
    validateModification(backend, modifiedSavepoint);
}
/**
 * Writes a new savepoint containing keyed account state and broadcast currency-rate state.
 *
 * @param backend state backend handed to the writer, or {@code null} to use the overload
 *     without a backend
 * @param savepointPath location the savepoint is written to
 * @throws Exception if the bootstrap job fails
 */
private void bootstrapState(StateBackend backend, String savepointPath) throws Exception {
    final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setRuntimeMode(RuntimeExecutionMode.AUTOMATIC);

    // Keyed state: one entry per account id.
    final StateBootstrapTransformation<Account> accountState =
            OperatorTransformation.bootstrapWith(env.fromCollection(accounts))
                    .keyBy(acc -> acc.id)
                    .transform(new AccountBootstrapper());

    // Broadcast state: currency name -> rate.
    final StateBootstrapTransformation<CurrencyRate> currencyState =
            OperatorTransformation.bootstrapWith(env.fromCollection(currencyRates))
                    .transform(new CurrencyBootstrapFunction());

    final SavepointWriter writer;
    if (backend != null) {
        writer = SavepointWriter.newSavepoint(backend, 128);
    } else {
        writer = SavepointWriter.newSavepoint(128);
    }

    writer.withOperator(ACCOUNT_UID, accountState)
            .withOperator(CURRENCY_UID, currencyState)
            .write(savepointPath);

    env.execute("Bootstrap");
}
/**
 * Restores a streaming job from the bootstrapped savepoint and verifies that it finishes
 * without error and emits one record per account.
 *
 * @param backend state backend to configure on the job, or {@code null} to leave it unset
 * @param savepointPath location of the savepoint to restore from
 * @throws Exception if the job cannot be submitted or its result cannot be retrieved
 */
private void validateBootstrap(StateBackend backend, String savepointPath) throws Exception {
    final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    if (backend != null) {
        env.setStateBackend(backend);
    }

    // Pipeline reading the keyed account state.
    final DataStream<Account> accountStream =
            env.fromCollection(accounts)
                    .keyBy(acc -> acc.id)
                    .flatMap(new UpdateAndGetAccount())
                    .uid(ACCOUNT_UID);
    final CompletableFuture<Collection<Account>> collected = collector.collect(accountStream);

    // Pipeline checking the broadcast currency-rate state.
    env.fromCollection(currencyRates)
            .connect(env.fromCollection(currencyRates).broadcast(descriptor))
            .process(new CurrencyValidationFunction())
            .uid(CURRENCY_UID)
            .addSink(new DiscardingSink<>());

    final JobGraph graph = env.getStreamGraph().getJobGraph();
    graph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath, false));

    final ClusterClient<?> client = MINI_CLUSTER_RESOURCE.getClusterClient();
    final Optional<SerializedThrowable> failure =
            client.submitJob(graph)
                    .thenCompose(client::requestJobResult)
                    .get()
                    .getSerializedThrowable();

    // Surface the job failure, including its cause, as the test failure.
    if (failure.isPresent()) {
        throw new AssertionError("Unexpected exception during bootstrapping", failure.get());
    }

    Assert.assertEquals("Unexpected output", 3, collected.get().size());
}
/**
 * Derives a new savepoint from an existing one by removing the currency operator and adding an
 * operator whose state is bootstrapped from the elements 1, 2 and 3.
 *
 * @param backend state backend handed to the writer, or {@code null} to use the overload
 *     without a backend
 * @param savepointPath location of the existing savepoint
 * @param modifyPath location the modified savepoint is written to
 * @throws Exception if the modification job fails
 */
private void modifySavepoint(StateBackend backend, String savepointPath, String modifyPath)
        throws Exception {
    final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setRuntimeMode(RuntimeExecutionMode.AUTOMATIC);

    final StateBootstrapTransformation<Integer> numbersState =
            OperatorTransformation.bootstrapWith(env.fromElements(1, 2, 3))
                    .transform(new ModifyProcessFunction());

    final SavepointWriter writer;
    if (backend != null) {
        writer = SavepointWriter.fromExistingSavepoint(savepointPath, backend);
    } else {
        writer = SavepointWriter.fromExistingSavepoint(savepointPath);
    }

    writer.removeOperator(CURRENCY_UID)
            .withOperator(MODIFY_UID, numbersState)
            .write(modifyPath);

    env.execute("Modifying");
}
/**
 * Restores a streaming job from the modified savepoint and verifies that it finishes without
 * error and emits one record per account.
 *
 * <p>A job failure is rethrown as an {@link AssertionError} that carries the failure as its
 * cause, so the reason is visible in the test report.
 *
 * @param backend state backend to configure on the job, or {@code null} to leave it unset
 * @param savepointPath location of the modified savepoint to restore from
 * @throws Exception if the job cannot be submitted or its result cannot be retrieved
 */
private void validateModification(StateBackend backend, String savepointPath) throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    if (backend != null) {
        env.setStateBackend(backend);
    }

    // Pipeline reading the keyed account state that was kept from the original savepoint.
    DataStream<Account> stream =
            env.fromCollection(accounts)
                    .keyBy(acc -> acc.id)
                    .flatMap(new UpdateAndGetAccount())
                    .uid(ACCOUNT_UID);
    CompletableFuture<Collection<Account>> results = collector.collect(stream);

    // Operator added by the modification; it checks its restored state in initializeState.
    stream.map(acc -> acc.id)
            .map(new StatefulOperator())
            .uid(MODIFY_UID)
            .addSink(new DiscardingSink<>());

    JobGraph jobGraph = env.getStreamGraph().getJobGraph();
    jobGraph.setSavepointRestoreSettings(
            SavepointRestoreSettings.forPath(savepointPath, false));

    ClusterClient<?> client = MINI_CLUSTER_RESOURCE.getClusterClient();
    Optional<SerializedThrowable> serializedThrowable =
            client.submitJob(jobGraph)
                    .thenCompose(client::requestJobResult)
                    .get()
                    .getSerializedThrowable();

    // Keep the cause attached instead of only asserting that no throwable is present.
    serializedThrowable.ifPresent(
            t -> {
                throw new AssertionError("Unexpected exception during modification", t);
            });

    Assert.assertEquals("Unexpected output", 3, results.get().size());
}
/**
 * A simple account value with public fields.
 *
 * <p>Equality and hash code are based on {@link #id} and {@link #amount} only; {@link
 * #timestamp} is not part of either.
 */
@SuppressWarnings("WeakerAccess")
public static class Account {
    public int id;
    public double amount;
    public long timestamp;

    /**
     * Creates an account with the given id and amount and a fixed timestamp of 1000.
     *
     * @param id account identifier
     * @param amount account balance
     */
    Account(int id, double amount) {
        this.id = id;
        this.amount = amount;
        this.timestamp = 1000L;
    }

    /**
     * Compares id and amount.
     *
     * <p>The amount is compared with {@link Double#compare(double, double)} rather than
     * {@code ==} so that the result agrees with {@link #hashCode()}: {@code NaN} equals
     * itself, and {@code 0.0} and {@code -0.0}, which hash differently, are not equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Account)) {
            return false;
        }
        Account other = (Account) obj;
        return id == other.id && Double.compare(amount, other.amount) == 0;
    }

    @Override
    public int hashCode() {
        return Objects.hash(id, amount);
    }
}
/**
 * A simple currency/rate pair with public fields.
 *
 * <p>Equality and hash code are based on both fields and tolerate {@code null} in either.
 */
@SuppressWarnings("WeakerAccess")
public static class CurrencyRate {
    public String currency;
    public Double rate;

    /**
     * Creates a rate entry.
     *
     * @param currency currency name
     * @param rate rate for that currency
     */
    CurrencyRate(String currency, double rate) {
        this.currency = currency;
        this.rate = rate;
    }

    /**
     * Compares currency and rate.
     *
     * <p>Uses {@link Objects#equals(Object, Object)} so that a {@code null} field does not
     * cause a {@link NullPointerException}, matching the null tolerance of {@link #hashCode()}.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof CurrencyRate)) {
            return false;
        }
        CurrencyRate other = (CurrencyRate) obj;
        return Objects.equals(currency, other.currency) && Objects.equals(rate, other.rate);
    }

    @Override
    public int hashCode() {
        return Objects.hash(currency, rate);
    }
}
/** Keyed bootstrap function that stores each account's amount in the value state "total". */
public static class AccountBootstrapper extends KeyedStateBootstrapFunction<Integer, Account> {
    ValueState<Double> state;

    /** Obtains the "total" value state from the runtime context. */
    @Override
    public void open(Configuration parameters) {
        state = getRuntimeContext().getState(new ValueStateDescriptor<>("total", Types.DOUBLE));
    }

    /** Writes the amount of the given account into the state. */
    @Override
    public void processElement(Account value, Context ctx) throws Exception {
        final double amount = value.amount;
        state.update(amount);
    }
}
/**
 * Streaming function that adds the amount stored in the value state "total" to each incoming
 * account, stores the new amount and emits the account.
 */
public static class UpdateAndGetAccount extends RichFlatMapFunction<Account, Account> {
    ValueState<Double> state;

    /** Obtains the "total" value state from the runtime context. */
    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        state = getRuntimeContext().getState(new ValueStateDescriptor<>("total", Types.DOUBLE));
    }

    /** Adds the stored amount, if any, to the account, then stores and emits the result. */
    @Override
    public void flatMap(Account value, Collector<Account> out) throws Exception {
        final Double stored = state.value();
        // When nothing is stored yet, the account keeps its own amount.
        final double updated = stored == null ? value.amount : value.amount + stored;
        value.amount = updated;
        state.update(updated);
        out.collect(value);
    }
}
/**
 * Bootstrap function that buffers every incoming integer and writes the buffer into the union
 * list state "numbers" when a snapshot is taken.
 */
public static class ModifyProcessFunction extends StateBootstrapFunction<Integer> {
    List<Integer> numbers;
    ListState<Integer> state;

    /** Looks up the union list state "numbers" from the operator state store. */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        final ListStateDescriptor<Integer> numbersDescriptor =
                new ListStateDescriptor<>("numbers", Types.INT);
        state = context.getOperatorStateStore().getUnionListState(numbersDescriptor);
    }

    /** Creates the empty in-memory buffer. */
    @Override
    public void open(Configuration parameters) {
        numbers = new ArrayList<>();
    }

    /** Appends the element to the in-memory buffer. */
    @Override
    public void processElement(Integer value, Context ctx) {
        numbers.add(value);
    }

    /** Replaces the content of the list state with the buffered elements. */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        state.clear();
        state.addAll(numbers);
    }
}
/**
 * Streaming map function that, when restored, asserts that its union list state "numbers"
 * contains exactly the elements 1, 2 and 3, each once.
 */
public static class StatefulOperator extends RichMapFunction<Integer, Integer>
        implements CheckpointedFunction {
    List<Integer> numbers;
    ListState<Integer> state;

    /** Looks up the list state and, if the job was restored, verifies its content. */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        final ListStateDescriptor<Integer> numbersDescriptor =
                new ListStateDescriptor<>("numbers", Types.INT);
        state = context.getOperatorStateStore().getUnionListState(numbersDescriptor);

        if (!context.isRestored()) {
            return;
        }

        // Tick off each restored element; removal fails for anything not (or no longer) expected.
        final Set<Integer> remaining = new HashSet<>(Arrays.asList(1, 2, 3));
        for (Integer restored : state.get()) {
            Assert.assertTrue("Duplicate state", remaining.remove(restored));
        }

        Assert.assertTrue(
                "Failed to bootstrap all state elements: "
                        + Arrays.toString(remaining.toArray()),
                remaining.isEmpty());
    }

    /** Creates the empty in-memory buffer. */
    @Override
    public void open(Configuration parameters) {
        numbers = new ArrayList<>();
    }

    /** Replaces the content of the list state with the buffered elements. */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        state.clear();
        state.addAll(numbers);
    }

    /** Ignores the input; the output is discarded by the test pipeline. */
    @Override
    public Integer map(Integer value) {
        return null;
    }
}
/** Broadcast bootstrap function that stores each rate in the broadcast state by currency. */
public static class CurrencyBootstrapFunction
        extends BroadcastStateBootstrapFunction<CurrencyRate> {

    /** Puts the rate of the given element into the broadcast state under its currency. */
    @Override
    public void processElement(CurrencyRate value, Context ctx) throws Exception {
        final String currency = value.currency;
        final Double rate = value.rate;
        ctx.getBroadcastState(descriptor).put(currency, rate);
    }
}
/**
 * Checks the restored broadcast state: every element on the non-broadcast input must find its
 * own rate in the broadcast state.
 */
public static class CurrencyValidationFunction
        extends BroadcastProcessFunction<CurrencyRate, CurrencyRate, Void> {

    /**
     * Asserts that the broadcast state holds the rate of the given element.
     *
     * <p>A missing entry is reported through an explicit assertion; without it the {@code null}
     * would be auto-unboxed for the numeric comparison and fail with an unspecific {@link
     * NullPointerException}.
     */
    @Override
    public void processElement(CurrencyRate value, ReadOnlyContext ctx, Collector<Void> out)
            throws Exception {
        final Double restoredRate = ctx.getBroadcastState(descriptor).get(value.currency);
        Assert.assertNotNull("Missing currency rate for " + value.currency, restoredRate);
        Assert.assertEquals("Incorrect currency rate", value.rate, restoredRate, 0.0001);
    }

    /** Broadcast elements are ignored; only the restored state is of interest. */
    @Override
    public void processBroadcastElement(CurrencyRate value, Context ctx, Collector<Void> out) {
        // ignore
    }
}
}