/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.checkpointing;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamGroupedReduceOperator;

import org.apache.flink.shaded.guava30.com.google.common.collect.EvictingQueue;

import org.junit.Assert;

import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.Random;

/**
 * Integration test ensuring that the persistent state defined by implementations of {@link
 * AbstractUdfStreamOperator} is correctly restored in case of recovery from a failure.
 *
 * <p>The topology currently tests the proper behaviour of the {@link StreamGroupedReduceOperator}.
 */
@SuppressWarnings("serial")
public class UdfStreamOperatorCheckpointingITCase extends StreamFaultToleranceTestBase {
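
    /** Number of sequence elements the source emits for each key. */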
private static final long NUM_INPUT = 500_000L;
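
    /** Capacity of the evicting queues that buffer the most recent sink outputs. */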
private static final int NUM_OUTPUT = 1_000;

    /**
     * Assembles a stream of (key, value) tuples and applies both a built-in aggregation and a
     * user-defined reduce function to it.
     */
@Override
public void testProgram(StreamExecutionEnvironment env) {
// base stream
KeyedStream<Tuple2<Integer, Long>, Tuple> stream =
env.addSource(new StatefulMultipleSequence()).keyBy(0);
stream
// testing built-in aggregate
.min(1)
// failure generation
.map(new OnceFailingIdentityMapFunction(NUM_INPUT))
.keyBy(0)
.addSink(new MinEvictingQueueSink());
stream
// testing UDF reducer
.reduce(
new ReduceFunction<Tuple2<Integer, Long>>() {
@Override
public Tuple2<Integer, Long> reduce(
Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2)
throws Exception {
return Tuple2.of(value1.f0, value1.f1 + value2.f1);
}
})
.keyBy(0)
.addSink(new SumEvictingQueueSink());
}

    @Override
public void postSubmit() {
// Note that these checks depend on the ordering of the input
// Checking the result of the built-in aggregate
for (int i = 0; i < PARALLELISM; i++) {
for (Long value : MinEvictingQueueSink.queues[i]) {
Assert.assertTrue("Value different from 1 found, was " + value + ".", value == 1);
}
}

        // Checking the result of the UDF reducer
for (int i = 0; i < PARALLELISM; i++) {
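            // The UDF reducer emits the running sums 1, 1+2, ..., 1+...+NUM_INPUT for every key,
            // and the evicting queue keeps only the last NUM_OUTPUT of them. Seed the expected
            // sum via the Gauss formula k * (k + 1) / 2 at k = NUM_INPUT - NUM_OUTPUT and replay
            // it forward.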
long prevCount = NUM_INPUT - NUM_OUTPUT;
long sum = prevCount * (prevCount + 1) / 2;
while (!SumEvictingQueueSink.queues[i].isEmpty()) {
sum += ++prevCount;
Long value = SumEvictingQueueSink.queues[i].remove();
Assert.assertTrue(
"Unexpected reduce value " + value + " instead of " + sum + ".",
value == sum);
}
}
}

    // --------------------------------------------------------------------------------------------
    // Custom Functions
    // --------------------------------------------------------------------------------------------

    /**
     * Produces the sequence 1, 2, ..., NUM_INPUT once per parallel instance of the downstream
     * operators, tagging each record with the index of the downstream subtask it is destined
     * for. The source is non-parallel to keep the emission order deterministic.
     */
private static class StatefulMultipleSequence extends RichSourceFunction<Tuple2<Integer, Long>>
implements ListCheckpointed<Long> {
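
        /** Number of sequence elements emitted so far; snapshotted and restored on recovery. */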
        private long count;

        @Override
public void run(SourceContext<Tuple2<Integer, Long>> ctx) throws Exception {
Object lock = ctx.getCheckpointLock();
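            // Emit under the checkpoint lock so that the records sent downstream and the
            // "count" state always agree with each other at checkpoint time.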
while (count < NUM_INPUT) {
synchronized (lock) {
for (int i = 0; i < PARALLELISM; i++) {
ctx.collect(Tuple2.of(i, count + 1));
}
count++;
}
}
}

        @Override
        public void cancel() {}

        @Override
        public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(this.count);
        }

        @Override
        public void restoreState(List<Long> state) throws Exception {
            if (state.size() != 1) {
                throw new RuntimeException(
                        "Test failed due to unexpected recovered state size " + state.size());
            }
            this.count = state.get(0);
        }
}

    /**
     * Identity mapper that fails exactly once, after seeing between 40% and 70% of the records
     * assigned to its subtask.
     */
private static class OnceFailingIdentityMapFunction
extends RichMapFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>>
implements ListCheckpointed<Long> {
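
        /** Shared across all subtasks so that the test job fails at most once. */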
        private static volatile boolean hasFailed = false;

        private final long numElements;

        private long failurePos;
        private long count;

        public OnceFailingIdentityMapFunction(long numElements) {
            this.numElements = numElements;
        }

        @Override
public void open(Configuration parameters) throws Exception {
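            // numElements is the global record count, so scale the 40%-70% failure window down
            // to this subtask's share of the input.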
long failurePosMin =
(long) (0.4 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
long failurePosMax =
(long) (0.7 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
            // Math.floorMod keeps the random offset non-negative, so the failure position
            // always lands inside [failurePosMin, failurePosMax).
            failurePos =
                    Math.floorMod(new Random().nextLong(), failurePosMax - failurePosMin)
                            + failurePosMin;
}

        @Override
        public Tuple2<Integer, Long> map(Tuple2<Integer, Long> value) throws Exception {
            if (!hasFailed && count >= failurePos) {
                hasFailed = true;
                throw new Exception("Test Failure");
            }
            count++;
            return value;
        }

        @Override
        public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(count);
        }

        @Override
        public void restoreState(List<Long> state) throws Exception {
            if (!state.isEmpty()) {
                count = state.get(0);
            }
        }
}

    /**
     * Sink that stores its input in an evicting queue holding the last {@link #NUM_OUTPUT}
     * elements. A separate queue is created for each group; apply a grouping prior to this
     * operator so that the queues are never accessed concurrently.
     */
private static class MinEvictingQueueSink implements SinkFunction<Tuple2<Integer, Long>> {
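        /** The last {@link #NUM_OUTPUT} values observed per key, indexed by the grouping key. */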
public static Queue<Long>[] queues = new Queue[PARALLELISM];

        @Override
public void invoke(Tuple2<Integer, Long> value) throws Exception {
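            // Lazily create the bounded queue for this key; EvictingQueue evicts its oldest
            // element once NUM_OUTPUT elements are buffered.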
if (queues[value.f0] == null) {
queues[value.f0] = EvictingQueue.create(NUM_OUTPUT);
}
queues[value.f0].add(value.f1);
}
}

    /**
     * Sink that stores its input in an evicting queue holding the last {@link #NUM_OUTPUT}
     * elements. A separate queue is created for each group; apply a grouping prior to this
     * operator so that the queues are never accessed concurrently.
     */
*/
private static class SumEvictingQueueSink implements SinkFunction<Tuple2<Integer, Long>> {
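        /** The last {@link #NUM_OUTPUT} values observed per key, indexed by the grouping key. */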
public static Queue<Long>[] queues = new Queue[PARALLELISM];

        @Override
public void invoke(Tuple2<Integer, Long> value) throws Exception {
if (queues[value.f0] == null) {
queues[value.f0] = EvictingQueue.create(NUM_OUTPUT);
}
queues[value.f0].add(value.f1);
}
}
}