/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation.streaming;
import static org.apache.beam.sdk.metrics.MetricResultsMatchers.attemptedMetricsResult;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineResult;
import org.apache.beam.runners.spark.TestSparkPipelineOptions;
import org.apache.beam.runners.spark.TestSparkRunner;
import org.apache.beam.runners.spark.UsesCheckpointRecovery;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.io.kafka.serialization.InstantDeserializer;
import org.apache.beam.sdk.io.kafka.serialization.InstantSerializer;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.MetricNameFilter;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.serialization.StringSerializer;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.TemporaryFolder;
/**
* Tests DStream recovery from checkpoint.
*
* <p>Runs the pipeline reading from a Kafka backlog with a watermark function that advances to
* infinity on an EOF signal. After resuming from checkpoint, a single output (guaranteed by the
* watermark) is asserted, along with {@link Metrics} values that are expected to resume from the
* previous count and a side input that is expected to recover as well.
*/
public class ResumeFromCheckpointStreamingTest implements Serializable {
private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER =
new EmbeddedKafkaCluster.EmbeddedZookeeper();
private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER =
new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(), new Properties());
private static final String TOPIC = "kafka_beam_test_topic";
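// Checkpoint directory root; re-created before each test.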
private transient TemporaryFolder temporaryFolder;
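// Do not reuse the Spark context across pipeline runs.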
@Rule public final transient ReuseSparkContextRule noContextReuse = ReuseSparkContextRule.no();
@BeforeClass
public static void setup() throws IOException {
EMBEDDED_ZOOKEEPER.startup();
EMBEDDED_KAFKA_CLUSTER.startup();
}
@Before
public void init() {
temporaryFolder = new TemporaryFolder();
try {
temporaryFolder.create();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
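// Produces the given key/timestamp records to the test topic on the embedded Kafka cluster.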
@SuppressWarnings("FutureReturnValueIgnored")
private static void produce(Map<String, Instant> messages) {
Properties producerProps = new Properties();
producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps());
producerProps.put("request.required.acks", 1);
producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList());
Serializer<String> stringSerializer = new StringSerializer();
Serializer<Instant> instantSerializer = new InstantSerializer();
try (KafkaProducer<String, Instant> kafkaProducer =
new KafkaProducer<>(producerProps, stringSerializer, instantSerializer)) {
for (Map.Entry<String, Instant> en : messages.entrySet()) {
kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue()));
}
}
}
@Test
@Category(UsesCheckpointRecovery.class)
public void testWithResume() throws Exception {
// write to Kafka
produce(
ImmutableMap.of(
"k1", new Instant(100),
"k2", new Instant(200),
"k3", new Instant(300),
"k4", new Instant(400)));
MetricsFilter metricsFilter =
MetricsFilter.builder()
.addNameFilter(MetricNameFilter.inNamespace(ResumeFromCheckpointStreamingTest.class))
.build();
// first run: stop once the watermark reaches the timestamp of the last injected element.
SparkPipelineResult res = run(Optional.of(new Instant(400)), 0);
assertThat(
res.metrics().queryMetrics(metricsFilter).getCounters(),
hasItem(
attemptedMetricsResult(
ResumeFromCheckpointStreamingTest.class.getName(),
"allMessages",
"EOFShallNotPassFn",
4L)));
assertThat(
res.metrics().queryMetrics(metricsFilter).getCounters(),
hasItem(
attemptedMetricsResult(
ResumeFromCheckpointStreamingTest.class.getName(),
"processedMessages",
"EOFShallNotPassFn",
4L)));
// --- between executions:
// - clear state.
clean();
// - write a bit more.
produce(
ImmutableMap.of(
"k5", new Instant(499),
"EOF", new Instant(500) // to be dropped from [0, 500).
));
// recovery should resume from last read offset, and read the second batch of input.
res = runAgain(1);
// assertions for the second run:
assertThat(
res.metrics().queryMetrics(metricsFilter).getCounters(),
hasItem(
attemptedMetricsResult(
ResumeFromCheckpointStreamingTest.class.getName(),
"processedMessages",
"EOFShallNotPassFn",
5L)));
assertThat(
res.metrics().queryMetrics(metricsFilter).getCounters(),
hasItem(
attemptedMetricsResult(
ResumeFromCheckpointStreamingTest.class.getName(),
"allMessages",
"EOFShallNotPassFn",
6L)));
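// count how many PAssertWithoutFlatten success counters fired.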
long successAssertions = 0;
Iterable<MetricResult<Long>> counterResults =
res.metrics()
.queryMetrics(
MetricsFilter.builder()
.addNameFilter(
MetricNameFilter.named(
PAssertWithoutFlatten.class, PAssert.SUCCESS_COUNTER))
.build())
.getCounters();
for (MetricResult<Long> counter : counterResults) {
if (counter.getAttempted() > 0) {
successAssertions++;
}
}
assertThat(
String.format("Expected %d successful assertions, but found %d.", 1L, successAssertions),
successAssertions,
is(1L));
// validate assertion didn't fail.
long failedAssertions = 0;
Iterable<MetricResult<Long>> failCounterResults =
res.metrics()
.queryMetrics(
MetricsFilter.builder()
.addNameFilter(
MetricNameFilter.named(
PAssertWithoutFlatten.class, PAssert.FAILURE_COUNTER))
.build())
.getCounters();
for (MetricResult<Long> counter : failCounterResults) {
if (counter.getAttempted() > 0) {
failedAssertions++;
}
}
assertThat(
String.format("Found %d failed assertions.", failedAssertions), failedAssertions, is(0L));
}
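// Re-runs the pipeline with no stop watermark, resuming from the existing checkpoint directory.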
private SparkPipelineResult runAgain(int expectedAssertions) {
// sleep before next run.
Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
return run(Optional.absent(), expectedAssertions);
}
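// Builds the streaming pipeline (Kafka read, EOF filter, fixed windows, group) and runs it with checkpointing enabled.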
private SparkPipelineResult run(Optional<Instant> stopWatermarkOption, int expectedAssertions) {
KafkaIO.Read<String, Instant> read =
KafkaIO.<String, Instant>read()
.withBootstrapServers(EMBEDDED_KAFKA_CLUSTER.getBrokerList())
.withTopics(Collections.singletonList(TOPIC))
.withKeyDeserializer(StringDeserializer.class)
.withValueDeserializer(InstantDeserializer.class)
.withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest"))
.withTimestampFn(KV::getValue)
.withWatermarkFn(
kv -> {
// at EOF move WM to infinity.
String key = kv.getKey();
Instant instant = kv.getValue();
return "EOF".equals(key) ? BoundedWindow.TIMESTAMP_MAX_VALUE : instant;
});
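// Configure TestSparkRunner for forced streaming, checkpointing into the temporary folder every batch interval.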
TestSparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setSparkMaster("local[*]");
options.setCheckpointDurationMillis(options.getBatchIntervalMillis());
options.setExpectedAssertions(expectedAssertions);
options.setRunner(TestSparkRunner.class);
options.setEnableSparkMetricSinks(false);
options.setForceStreaming(true);
options.setCheckpointDir(temporaryFolder.getRoot().getPath());
// the stop watermark (which acts as a per-execution timeout) can be injected by the caller.
if (stopWatermarkOption.isPresent()) {
options.setStopPipelineWatermark(stopWatermarkOption.get().getMillis());
}
Pipeline p = Pipeline.create(options);
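// Static side input; EOFShallNotPassFn asserts that it is still available after resuming from checkpoint.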
PCollection<String> expectedCol =
p.apply(Create.of(ImmutableList.of("side1", "side2")).withCoder(StringUtf8Coder.of()));
PCollectionView<List<String>> view = expectedCol.apply(View.asList());
PCollection<KV<String, Instant>> kafkaStream = p.apply(read.withoutMetadata());
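// Take the Kafka keys, drop EOF, window into fixed 500ms windows, and group under a single key so one pane fires when the watermark passes the end of the window.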
PCollection<Iterable<String>> grouped =
kafkaStream
.apply(Keys.create())
.apply("EOFShallNotPassFn", ParDo.of(new EOFShallNotPassFn(view)).withSideInputs(view))
.apply(
Window.<String>into(FixedWindows.of(Duration.millis(500)))
.triggering(AfterWatermark.pastEndOfWindow())
.accumulatingFiredPanes()
.withAllowedLateness(Duration.ZERO))
.apply(WithKeys.of(1))
.apply(GroupByKey.create())
.apply(Values.create());
grouped.apply(new PAssertWithoutFlatten<>("k1", "k2", "k3", "k4", "k5"));
return (SparkPipelineResult) p.run();
}
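// Clears the static accumulators, watermark holder, and microbatch source cache between runs.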
@After
public void clean() {
AggregatorsAccumulator.clear();
MetricsAccumulator.clear();
GlobalWatermarkHolder.clear();
MicrobatchSource.clearCache();
}
@AfterClass
public static void tearDown() {
EMBEDDED_KAFKA_CLUSTER.shutdown();
EMBEDDED_ZOOKEEPER.shutdown();
}
/** A pass-through fn that prevents the EOF event from passing. */
private static class EOFShallNotPassFn extends DoFn<String, String> {
final PCollectionView<List<String>> view;
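// processedMessages counts only elements that pass through; allMessages counts every element, including EOF.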
private final Counter aggregator =
Metrics.counter(ResumeFromCheckpointStreamingTest.class, "processedMessages");
final Counter counter = Metrics.counter(ResumeFromCheckpointStreamingTest.class, "allMessages");
private EOFShallNotPassFn(PCollectionView<List<String>> view) {
this.view = view;
}
@ProcessElement
public void process(ProcessContext c) {
String element = c.element();
// assert that side input is passed correctly before/after resuming from checkpoint.
assertThat(c.sideInput(view), containsInAnyOrder("side1", "side2"));
counter.inc();
if (!"EOF".equals(element)) {
aggregator.inc();
c.output(c.element());
}
}
}
/**
* A custom PAssert that avoids using {@link org.apache.beam.sdk.transforms.Flatten} until
* BEAM-1444 is resolved.
*/
private static class PAssertWithoutFlatten<T>
extends PTransform<PCollection<Iterable<T>>, PDone> {
private final T[] expected;
@SafeVarargs
private PAssertWithoutFlatten(T... expected) {
this.expected = expected;
}
@Override
public PDone expand(PCollection<Iterable<T>> input) {
input.apply(ParDo.of(new AssertDoFn<>(expected)));
return PDone.in(input.getPipeline());
}
private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
private final Counter success =
Metrics.counter(PAssertWithoutFlatten.class, PAssert.SUCCESS_COUNTER);
private final Counter failure =
Metrics.counter(PAssertWithoutFlatten.class, PAssert.FAILURE_COUNTER);
private final T[] expected;
AssertDoFn(T[] expected) {
this.expected = expected;
}
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
try {
assertThat(c.element(), containsInAnyOrder(expected));
success.inc();
} catch (Throwable t) {
failure.inc();
throw t;
}
}
}
}
}