blob: f4f8b18ffc548964a0c1dffeb3c99538b85765fa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.pubsub;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.io.gcp.pubsub.TestPubsub.createTopicName;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SubscriptionPath;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestPipelineOptions;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Test rule which observes elements of the {@link PCollection} and checks whether they match the
* success criteria.
*
* <p>Uses a random temporary Pubsub topic for synchronization.
*/
public class TestPubsubSignal implements TestRule {
private static final Logger LOG = LoggerFactory.getLogger(TestPubsubSignal.class);

// Topic-name suffixes and message payloads for the two synchronization topics.
private static final String RESULT_TOPIC_NAME = "result";
private static final String RESULT_SUCCESS_MESSAGE = "SUCCESS";
private static final String START_TOPIC_NAME = "start";
private static final String START_SIGNAL_MESSAGE = "START SIGNAL";

// Passed to the client factory to indicate that no id / timestamp attribute is used.
private static final String NO_ID_ATTRIBUTE = null;
private static final String NO_TIMESTAMP_ATTRIBUTE = null;

// Non-null only between initializePubsub() and tearDown(), i.e. while a test is running.
PubsubClient pubsub;

// Assigned once in the constructor and never replaced afterwards.
private final TestPubsubOptions pipelineOptions;

// Set once the corresponding topic has been created for the current test.
private @Nullable TopicPath resultTopicPath = null;
private @Nullable TopicPath startTopicPath = null;
/**
 * Builds a new instance of this rule.
 *
 * <p>GCP configuration comes from the testing pipeline options (see {@link
 * TestPipelineOptions}), viewed as {@link TestPubsubOptions}.
 */
public static TestPubsubSignal create() {
  return new TestPubsubSignal(
      TestPipeline.testingPipelineOptions().as(TestPubsubOptions.class));
}
/** Private: instances are obtained through {@link #create()}. */
private TestPubsubSignal(TestPubsubOptions options) {
  this.pipelineOptions = options;
}
/**
 * Wraps {@code base} so that the Pubsub client and topics are set up before the test body runs
 * and torn down afterwards, whether or not the test body throws.
 *
 * <p>Fails with an {@link AssertionError} before running the test if a client from an earlier
 * test is still present.
 */
@Override
public Statement apply(Statement base, Description description) {
  return new Statement() {
    @Override
    public void evaluate() throws Throwable {
      // A leftover client means the previous test never completed its teardown.
      boolean staleClient = TestPubsubSignal.this.pubsub != null;
      if (staleClient) {
        String message =
            "Pubsub client was not shutdown in previous test. "
                + "Topic path is'"
                + resultTopicPath
                + "'. "
                + "Current test: "
                + description.getDisplayName();
        throw new AssertionError(message);
      }

      try {
        initializePubsub(description);
        base.evaluate();
      } finally {
        tearDown();
      }
    }
  };
}
/**
 * Opens the Pubsub client and creates the temporary result and start topics for one test.
 *
 * <p>Each topic path is stored in its field immediately after that topic has been created. A
 * non-null field signals that the topic needs teardown, so a failure while creating the second
 * topic must not hide the first one from {@code tearDown()}.
 *
 * @param description the running test, used to derive the topic names
 * @throws IOException propagated from client creation or topic creation
 */
private void initializePubsub(Description description) throws IOException {
  pubsub =
      PubsubGrpcClient.FACTORY.newClient(
          NO_TIMESTAMP_ATTRIBUTE, NO_ID_ATTRIBUTE, pipelineOptions);

  // Example topic name:
  // integ-test-TestClassName-testMethodName-2018-12-11-23-32-333-<random-long>-result
  TopicPath resultTopicPathTmp =
      PubsubClient.topicPathFromName(
          pipelineOptions.getProject(), createTopicName(description, RESULT_TOPIC_NAME));
  TopicPath startTopicPathTmp =
      PubsubClient.topicPathFromName(
          pipelineOptions.getProject(), createTopicName(description, START_TOPIC_NAME));

  // Record each path only after its own creation succeeded, but before attempting the next one.
  pubsub.createTopic(resultTopicPathTmp);
  resultTopicPath = resultTopicPathTmp;

  pubsub.createTopic(startTopicPathTmp);
  startTopicPath = startTopicPathTmp;
}
/**
 * Deletes the temporary topics of the current test and closes the Pubsub client.
 *
 * <p>Does nothing if no client is open. Deletion of the start topic is attempted even if deleting
 * the result topic throws, and the client reference and both topic paths are always cleared so
 * that the next test starts from a clean state.
 *
 * @throws IOException propagated from topic deletion or from closing the client
 */
private void tearDown() throws IOException {
  if (pubsub == null) {
    return;
  }

  try {
    try {
      if (resultTopicPath != null) {
        pubsub.deleteTopic(resultTopicPath);
      }
    } finally {
      // Previously the start topic was never deleted and leaked on every test run.
      if (startTopicPath != null) {
        pubsub.deleteTopic(startTopicPath);
      }
    }
  } finally {
    try {
      pubsub.close();
    } finally {
      // Reset even when close() throws; otherwise the next test would fail on a stale client.
      pubsub = null;
      resultTopicPath = null;
      startTopicPath = null;
    }
  }
}
/**
 * Returns a transform that publishes the start message to this rule's start topic, announcing
 * that the pipeline has started.
 */
public PTransform<PBegin, PDone> signalStart() {
  TopicPath target = startTopicPath;
  return new PublishStart(target);
}
/**
 * Returns a transform that publishes a success message once {@code successPredicate} holds.
 *
 * <p>{@code successPredicate} is a {@link SerializableFunction} that receives the set of events
 * captured so far and returns true when that set satisfies the success criteria.
 *
 * <ul>
 *   <li>If it returns false, it is evaluated again when the next event becomes available.
 *   <li>If it returns true, success is signaled and {@link #waitForSuccess(Duration)} unblocks.
 *   <li>If it throws, failure is signaled and {@link #waitForSuccess(Duration)} unblocks.
 * </ul>
 *
 * @param coder coder for the observed elements
 * @param formatter function rendering an element as a string
 * @param successPredicate success criteria over the set of events seen so far
 */
public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
    Coder<T> coder,
    SerializableFunction<T, String> formatter,
    SerializableFunction<Set<T>, Boolean> successPredicate) {
  PublishSuccessWhen<T> publishWhenSatisfied =
      new PublishSuccessWhen<>(coder, formatter, successPredicate, resultTopicPath);
  return publishWhenSatisfied;
}
/**
 * Convenience overload of {@link #signalSuccessWhen(Coder, SerializableFunction,
 * SerializableFunction)} that formats elements with {@link Object#toString}.
 */
public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
    Coder<T> coder, SerializableFunction<Set<T>, Boolean> successPredicate) {
  SerializableFunction<T, String> toStringFormatter = T::toString;
  return signalSuccessWhen(coder, toStringFormatter, successPredicate);
}
/**
 * Future that waits for a start signal for {@code duration}.
 *
 * <p>This future must be created before running the pipeline. A subscription must exist prior to
 * the start signal being published, which occurs immediately upon pipeline startup. The
 * subscription is created eagerly by this method; polling happens when the returned supplier is
 * first evaluated.
 *
 * <p>Once the start signal has been received and verified, the temporary subscription is deleted
 * on a best-effort basis. It is deliberately kept when waiting fails, so that evaluating the
 * supplier again can still poll the same subscription.
 *
 * @param duration how long to poll; also used, in seconds, as the subscription's ack deadline
 * @return a memoized supplier that blocks until the start signal arrives
 * @throws IOException propagated from creating the subscription
 */
public Supplier<Void> waitForStart(Duration duration) throws IOException {
  SubscriptionPath startSubscriptionPath =
      PubsubClient.subscriptionPathFromName(
          pipelineOptions.getProject(),
          "start-subscription-" + ThreadLocalRandom.current().nextLong());

  pubsub.createSubscription(
      startTopicPath, startSubscriptionPath, (int) duration.getStandardSeconds());

  return Suppliers.memoize(
      () -> {
        try {
          String result = pollForResultForDuration(startSubscriptionPath, duration);
          checkState(
              START_SIGNAL_MESSAGE.equals(result),
              "Expected start signal but received: %s",
              result);
        } catch (IOException e) {
          throw new RuntimeException(e);
        }

        // Cleanup must not turn a successfully received start signal into a failure.
        try {
          pubsub.deleteSubscription(startSubscriptionPath);
        } catch (IOException | RuntimeException e) {
          LOG.warn("Leaked Pubsub subscription {}", startSubscriptionPath, e);
        }
        return null;
      });
}
/**
 * Waits for a success signal for {@code duration}.
 *
 * <p>Creates a temporary subscription on the result topic, polls it for a single message and
 * deletes the subscription again on a best-effort basis, whether or not a message arrived.
 *
 * @param duration how long to poll; also used, in seconds, as the subscription's ack deadline
 * @throws AssertionError if no message arrives in time, or if the received message is not the
 *     success message (the received text becomes the error message)
 * @throws IOException propagated from creating the subscription or from polling
 */
public void waitForSuccess(Duration duration) throws IOException {
  SubscriptionPath resultSubscriptionPath =
      PubsubClient.subscriptionPathFromName(
          pipelineOptions.getProject(),
          "result-subscription-" + ThreadLocalRandom.current().nextLong());

  pubsub.createSubscription(
      resultTopicPath, resultSubscriptionPath, (int) duration.getStandardSeconds());

  String result;
  try {
    result = pollForResultForDuration(resultSubscriptionPath, duration);
  } finally {
    // Cleanup failures are logged rather than thrown so they never mask the test outcome.
    try {
      pubsub.deleteSubscription(resultSubscriptionPath);
    } catch (IOException | RuntimeException e) {
      LOG.warn("Leaked Pubsub subscription {}", resultSubscriptionPath, e);
    }
  }

  if (!RESULT_SUCCESS_MESSAGE.equals(result)) {
    throw new AssertionError(result);
  }
}
/**
 * Polls {@code signalSubscriptionPath} until one message arrives or {@code duration} elapses.
 *
 * <p>At least one pull is always attempted. Each pull asks for at most one message; a received
 * message is acknowledged and its payload returned. An empty pull response is handled like a
 * failed pull: the loop pauses briefly and tries again instead of reading from an empty list.
 * gRPC errors are retried as well, and all of them except {@code DEADLINE_EXCEEDED} are logged.
 *
 * @param signalSubscriptionPath subscription to pull from
 * @param duration maximum time to keep polling
 * @return the payload of the received message, decoded as UTF-8
 * @throws AssertionError if no message was received before the deadline
 * @throws IOException propagated from the Pubsub client
 */
private String pollForResultForDuration(
    SubscriptionPath signalSubscriptionPath, Duration duration) throws IOException {
  List<PubsubClient.IncomingMessage> signal = null;
  DateTime endPolling = DateTime.now().plus(duration.getMillis());

  do {
    try {
      List<PubsubClient.IncomingMessage> pulled =
          pubsub.pull(DateTime.now().getMillis(), signalSubscriptionPath, 1, false);
      if (pulled != null && !pulled.isEmpty()) {
        signal = pulled;
        pubsub.acknowledge(
            signalSubscriptionPath, signal.stream().map(m -> m.ackId).collect(toList()));
        break;
      }
      // Nothing delivered yet; pause before the next pull to avoid a busy loop.
      sleep(500);
    } catch (StatusRuntimeException e) {
      // Status#equals is identity-based, so the status has to be compared by its code.
      if (e.getStatus().getCode() != Status.Code.DEADLINE_EXCEEDED) {
        LOG.warn(
            "(Will retry) Error while polling {} for signal: {}",
            signalSubscriptionPath,
            e.getStatus());
      }
      sleep(500);
    }
  } while (DateTime.now().isBefore(endPolling));

  if (signal == null) {
    throw new AssertionError(
        String.format(
            "Did not receive signal on %s in %ss",
            signalSubscriptionPath, duration.getStandardSeconds()));
  }

  return new String(signal.get(0).elementBytes, UTF_8);
}
/**
 * Sleeps for {@code t} milliseconds.
 *
 * <p>If the thread is interrupted while sleeping, the interrupt flag is restored before the
 * interruption is rethrown as a {@link RuntimeException}, so callers further up the stack can
 * still observe it.
 *
 * @param t time to sleep, in milliseconds
 */
private void sleep(long t) {
  try {
    Thread.sleep(t);
  } catch (InterruptedException ex) {
    // Thread.sleep clears the interrupt flag when it throws; set it again.
    Thread.currentThread().interrupt();
    throw new RuntimeException(ex);
  }
}
/** {@link PTransform} that publishes a single start message to the given topic. */
static class PublishStart extends PTransform<PBegin, PDone> {
  private final TopicPath topic;

  PublishStart(TopicPath startTopicPath) {
    this.topic = startTopicPath;
  }

  @Override
  public PDone expand(PBegin input) {
    // One-element collection holding the start message.
    PCollection<String> startSignal =
        input.apply("Start signal", Create.of(START_SIGNAL_MESSAGE));
    return startSignal.apply(PubsubIO.writeStrings().to(topic.getPath()));
  }
}
/**
 * {@link PTransform} that checks whether the elements seen so far match the success criteria and
 * publishes the outcome to the result topic.
 */
static class PublishSuccessWhen<T> extends PTransform<PCollection<? extends T>, POutput> {
  private final Coder<T> coder;
  private final SerializableFunction<T, String> formatter;
  private final SerializableFunction<Set<T>, Boolean> successPredicate;
  private final TopicPath resultTopicPath;

  PublishSuccessWhen(
      Coder<T> coder,
      SerializableFunction<T, String> formatter,
      SerializableFunction<Set<T>, Boolean> successPredicate,
      TopicPath resultTopicPath) {
    this.resultTopicPath = resultTopicPath;
    this.successPredicate = successPredicate;
    this.formatter = formatter;
    this.coder = coder;
  }

  @Override
  public POutput expand(PCollection<? extends T> input) {
    StatefulPredicateCheck<T> predicateCheck =
        new StatefulPredicateCheck<>(coder, formatter, successPredicate);

    // Every element gets the same key in the global window, so that all observed events
    // accumulate in a single state cell.
    return input
        .apply(Window.into(new GlobalWindows()))
        .apply(WithKeys.of("dummyKey"))
        .apply("checkAllEventsForSuccess", ParDo.of(predicateCheck))
        // Publish whatever the check emits (success or failure) to the result topic.
        .apply("publishSuccess", PubsubIO.writeStrings().to(resultTopicPath.getPath()));
  }
}
/**
 * Stateful {@link DoFn} which caches the elements it sees and checks whether they satisfy the
 * predicate.
 *
 * <p>After each element, the predicate is applied to the set of all elements seen so far. When it
 * returns true the success message is output. If it throws, a message starting with {@code
 * "FAILURE: "} is output, followed by the exception's message, or by the exception's string form
 * when it has no message. When it returns false nothing is output.
 */
static class StatefulPredicateCheck<T> extends DoFn<KV<String, ? extends T>, String> {
  // Stored for callers that supply one; not applied by this DoFn.
  private final SerializableFunction<T, String> formatter;
  private final SerializableFunction<Set<T>, Boolean> successPredicate;

  // keep all events seen so far in the state cell
  private static final String SEEN_EVENTS = "seenEvents";

  @StateId(SEEN_EVENTS)
  private final StateSpec<BagState<T>> seenEvents;

  StatefulPredicateCheck(
      Coder<T> coder,
      SerializableFunction<T, String> formatter,
      SerializableFunction<Set<T>, Boolean> successPredicate) {
    this.seenEvents = StateSpecs.bag(coder);
    this.formatter = formatter;
    this.successPredicate = successPredicate;
  }

  /**
   * Adds the element's value to the bag state and re-evaluates the predicate on everything seen
   * so far.
   */
  @ProcessElement
  public void processElement(
      ProcessContext context, @StateId(SEEN_EVENTS) BagState<T> seenEvents) {
    seenEvents.add(context.element().getValue());
    ImmutableSet<T> eventsSoFar = ImmutableSet.copyOf(seenEvents.read());

    // check if all elements seen so far satisfy the success predicate
    try {
      if (successPredicate.apply(eventsSoFar)) {
        context.output(RESULT_SUCCESS_MESSAGE);
      }
    } catch (Throwable e) {
      // Exceptions without a message (e.g. many NPEs) would otherwise yield "FAILURE: null".
      String detail = e.getMessage() != null ? e.getMessage() : e.toString();
      context.output("FAILURE: " + detail);
    }
  }
}
}