blob: d0753cc799d25f7141908c4ec716de6d1ffbe6fe [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.pubsub;
import static org.apache.beam.sdk.io.gcp.pubsub.TestPubsub.createTopicName;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import com.google.api.gax.rpc.ApiException;
import com.google.cloud.pubsub.v1.AckReplyConsumer;
import com.google.cloud.pubsub.v1.MessageReceiver;
import com.google.cloud.pubsub.v1.Subscriber;
import com.google.cloud.pubsub.v1.SubscriptionAdminClient;
import com.google.cloud.pubsub.v1.SubscriptionAdminSettings;
import com.google.cloud.pubsub.v1.TopicAdminClient;
import com.google.cloud.pubsub.v1.TopicAdminSettings;
import com.google.pubsub.v1.PushConfig;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SubscriptionPath;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestPipelineOptions;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.joda.time.Seconds;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Test rule which observes elements of the {@link PCollection} and checks whether they match the
 * success criteria.
 *
 * <p>Uses two random temporary Pubsub topics for synchronization: one carries the "pipeline
 * started" signal and the other carries the success/failure result.
 */
@SuppressWarnings({
  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class TestPubsubSignal implements TestRule {
  private static final Logger LOG = LoggerFactory.getLogger(TestPubsubSignal.class);
  private static final String RESULT_TOPIC_NAME = "result";
  private static final String RESULT_SUCCESS_MESSAGE = "SUCCESS";
  private static final String START_TOPIC_NAME = "start";
  private static final String START_SIGNAL_MESSAGE = "START SIGNAL";

  // Pub/Sub only accepts subscription ack deadlines within [10, 600] seconds.
  private static final int MIN_ACK_DEADLINE_SECONDS = 10;
  private static final int MAX_ACK_DEADLINE_SECONDS = 600;

  // How often the polling loop re-checks whether a signal message has arrived.
  private static final long POLL_INTERVAL_MILLIS = 1000;

  private final TestPubsubOptions pipelineOptions;
  private final String pubsubEndpoint;

  // Per-test state: set by initializePubsub and reset to null by tearDown.
  private @Nullable TopicAdminClient topicAdmin = null;
  private @Nullable SubscriptionAdminClient subscriptionAdmin = null;
  private @Nullable TopicPath resultTopicPath = null;
  private @Nullable TopicPath startTopicPath = null;

  /**
   * Creates an instance of this rule.
   *
   * <p>Loads GCP configuration from {@link TestPipelineOptions}.
   */
  public static TestPubsubSignal create() {
    TestPubsubOptions options = TestPipeline.testingPipelineOptions().as(TestPubsubOptions.class);
    return new TestPubsubSignal(options);
  }

  private TestPubsubSignal(TestPubsubOptions pipelineOptions) {
    this.pipelineOptions = pipelineOptions;
    this.pubsubEndpoint = PubsubOptions.targetForRootUrl(this.pipelineOptions.getPubsubRootUrl());
  }

  /**
   * Wraps {@code base} so that fresh start/result topics exist while the test runs and are torn
   * down afterwards, even if the test or the setup fails.
   */
  @Override
  public Statement apply(Statement base, Description description) {
    return new Statement() {
      @Override
      public void evaluate() throws Throwable {
        if (topicAdmin != null || subscriptionAdmin != null) {
          throw new AssertionError(
              "Pubsub client was not shutdown in previous test. "
                  + "Topic path is '"
                  + resultTopicPath
                  + "'. "
                  + "Current test: "
                  + description.getDisplayName());
        }
        try {
          initializePubsub(description);
          base.evaluate();
        } finally {
          tearDown();
        }
      }
    };
  }

  /** Creates the admin clients and the per-test start and result topics. */
  private void initializePubsub(Description description) throws IOException {
    topicAdmin =
        TopicAdminClient.create(
            TopicAdminSettings.newBuilder()
                .setCredentialsProvider(pipelineOptions::getGcpCredential)
                .setEndpoint(pubsubEndpoint)
                .build());
    subscriptionAdmin =
        SubscriptionAdminClient.create(
            SubscriptionAdminSettings.newBuilder()
                .setCredentialsProvider(pipelineOptions::getGcpCredential)
                .setEndpoint(pubsubEndpoint)
                .build());

    // Example topic name:
    //    integ-test-TestClassName-testMethodName-2018-12-11-23-32-333-<random-long>-result
    TopicPath resultTopicPathTmp =
        PubsubClient.topicPathFromName(
            pipelineOptions.getProject(), createTopicName(description, RESULT_TOPIC_NAME));
    TopicPath startTopicPathTmp =
        PubsubClient.topicPathFromName(
            pipelineOptions.getProject(), createTopicName(description, START_TOPIC_NAME));

    // Record each path right after its own topic is created; a non-null path tells tearDown the
    // topic exists. Doing it per topic keeps the result topic from leaking if creating the start
    // topic fails.
    topicAdmin.createTopic(resultTopicPathTmp.getPath());
    resultTopicPath = resultTopicPathTmp;
    topicAdmin.createTopic(startTopicPathTmp.getPath());
    startTopicPath = startTopicPathTmp;
  }

  /**
   * Deletes the topics created by {@link #initializePubsub} together with their subscriptions,
   * then closes the admin clients and resets all per-test state.
   *
   * <p>Whichever clients were created are always closed and the fields are always cleared, even
   * when initialization failed part-way. Otherwise the next test would fail the "not shutdown in
   * previous test" check.
   */
  private void tearDown() throws IOException {
    try {
      if (topicAdmin != null && subscriptionAdmin != null) {
        deleteTopicAndSubscriptions(resultTopicPath);
        deleteTopicAndSubscriptions(startTopicPath);
      }
    } finally {
      SubscriptionAdminClient subscriptionAdminToClose = subscriptionAdmin;
      TopicAdminClient topicAdminToClose = topicAdmin;
      subscriptionAdmin = null;
      topicAdmin = null;
      resultTopicPath = null;
      startTopicPath = null;
      try {
        if (subscriptionAdminToClose != null) {
          subscriptionAdminToClose.close();
        }
      } finally {
        if (topicAdminToClose != null) {
          topicAdminToClose.close();
        }
      }
    }
  }

  /** Deletes every subscription of {@code topicPath} and then the topic; no-op when null. */
  private void deleteTopicAndSubscriptions(@Nullable TopicPath topicPath) {
    if (topicPath == null) {
      return;
    }
    for (String subscriptionPath :
        topicAdmin.listTopicSubscriptions(topicPath.getPath()).iterateAll()) {
      subscriptionAdmin.deleteSubscription(subscriptionPath);
    }
    topicAdmin.deleteTopic(topicPath.getPath());
  }

  /** Outputs a message that the pipeline has started. */
  public PTransform<PBegin, PDone> signalStart() {
    return new PublishStart(startTopicPath);
  }

  /**
   * Outputs a success message when {@code successPredicate} is evaluated to true.
   *
   * <p>{@code successPredicate} is a {@link SerializableFunction} that accepts a set of currently
   * captured events and returns true when the set satisfies the success criteria.
   *
   * <p>If {@code successPredicate} is evaluated to false, then it will be re-evaluated when next
   * event becomes available.
   *
   * <p>If {@code successPredicate} is evaluated to true, then a success will be signaled and {@link
   * #waitForSuccess(Duration)} will unblock.
   *
   * <p>If {@code successPredicate} throws, then failure will be signaled and {@link
   * #waitForSuccess(Duration)} will unblock.
   */
  public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
      Coder<T> coder,
      SerializableFunction<T, String> formatter,
      SerializableFunction<Set<T>, Boolean> successPredicate) {
    return new PublishSuccessWhen<>(coder, formatter, successPredicate, resultTopicPath);
  }

  /**
   * Invocation of {@link #signalSuccessWhen(Coder, SerializableFunction, SerializableFunction)}
   * with {@link Object#toString} as the formatter.
   */
  public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
      Coder<T> coder, SerializableFunction<Set<T>, Boolean> successPredicate) {
    return signalSuccessWhen(coder, T::toString, successPredicate);
  }

  /**
   * Returns a {@link Supplier} that, on first invocation, waits up to {@code duration} for the
   * start signal. The result is memoized.
   *
   * <p>This supplier must be created before running the pipeline. The subscription it listens on
   * is created eagerly here, and must exist before the start signal is published, which happens
   * immediately upon pipeline startup.
   */
  public Supplier<Void> waitForStart(Duration duration) throws IOException {
    SubscriptionPath startSubscriptionPath =
        createSignalSubscription("start-subscription-", startTopicPath, duration);

    return Suppliers.memoize(
        () -> {
          try {
            String result = pollForResultForDuration(startSubscriptionPath, duration);
            checkState(
                START_SIGNAL_MESSAGE.equals(result),
                "Expected start signal '%s' but received '%s'",
                START_SIGNAL_MESSAGE,
                result);
            return null;
          } catch (IOException e) {
            throw new RuntimeException(e);
          } finally {
            deleteSubscriptionQuietly(startSubscriptionPath);
          }
        });
  }

  /**
   * Waits up to {@code duration} for a success signal.
   *
   * @throws AssertionError if no signal arrives in time, or if the signal is not a success (the
   *     received failure message becomes the error message)
   */
  public void waitForSuccess(Duration duration) throws IOException {
    SubscriptionPath resultSubscriptionPath =
        createSignalSubscription("result-subscription-", resultTopicPath, duration);

    String result;
    try {
      result = pollForResultForDuration(resultSubscriptionPath, duration);
    } finally {
      // Runs on timeout too, so the subscription is not leaked when polling throws.
      deleteSubscriptionQuietly(resultSubscriptionPath);
    }

    if (!RESULT_SUCCESS_MESSAGE.equals(result)) {
      throw new AssertionError(result);
    }
  }

  /**
   * Creates a randomly named pull subscription on {@code topicPath}.
   *
   * @param namePrefix prefix of the subscription name; a random long is appended
   * @param ackDeadline requested ack deadline, clamped to the range Pub/Sub accepts
   */
  private SubscriptionPath createSignalSubscription(
      String namePrefix, TopicPath topicPath, Duration ackDeadline) {
    SubscriptionPath subscriptionPath =
        PubsubClient.subscriptionPathFromName(
            pipelineOptions.getProject(), namePrefix + ThreadLocalRandom.current().nextLong());
    subscriptionAdmin.createSubscription(
        subscriptionPath.getPath(),
        topicPath.getPath(),
        PushConfig.getDefaultInstance(),
        ackDeadlineSeconds(ackDeadline));
    return subscriptionPath;
  }

  /** Clamps {@code duration} to the ack-deadline range Pub/Sub accepts for subscriptions. */
  private static int ackDeadlineSeconds(Duration duration) {
    long seconds = duration.getStandardSeconds();
    return (int) Math.max(MIN_ACK_DEADLINE_SECONDS, Math.min(MAX_ACK_DEADLINE_SECONDS, seconds));
  }

  /** Deletes {@code subscriptionPath}, logging instead of throwing if the API call fails. */
  private void deleteSubscriptionQuietly(SubscriptionPath subscriptionPath) {
    try {
      subscriptionAdmin.deleteSubscription(subscriptionPath.getPath());
    } catch (ApiException e) {
      LOG.error("Leaked PubSub subscription '{}'", subscriptionPath, e);
    }
  }

  /**
   * Pulls from {@code signalSubscriptionPath} until the first non-empty message arrives or {@code
   * timeoutDuration} elapses.
   *
   * @return the UTF-8 payload of the first non-empty message
   * @throws AssertionError if no non-empty message arrives in time or the thread is interrupted
   */
  private String pollForResultForDuration(
      SubscriptionPath signalSubscriptionPath, Duration timeoutDuration) throws IOException {

    AtomicReference<String> result = new AtomicReference<>(null);

    MessageReceiver receiver =
        (com.google.pubsub.v1.PubsubMessage message, AckReplyConsumer replyConsumer) -> {
          // Ignore empty messages: ack them so they are not redelivered, but do not treat them
          // as a signal (and do not ack them twice).
          if (message.getData().isEmpty()) {
            replyConsumer.ack();
            return;
          }
          // Only the first non-empty message is kept; later ones are nacked.
          if (result.compareAndSet(null, message.getData().toStringUtf8())) {
            replyConsumer.ack();
          } else {
            replyConsumer.nack();
          }
        };

    Subscriber subscriber =
        Subscriber.newBuilder(signalSubscriptionPath.getPath(), receiver)
            .setCredentialsProvider(pipelineOptions::getGcpCredential)
            .setEndpoint(pubsubEndpoint)
            .build();
    subscriber.startAsync();
    try {
      DateTime startTime = new DateTime();
      int timeoutSeconds = timeoutDuration.toStandardSeconds().getSeconds();
      while (result.get() == null
          && Seconds.secondsBetween(startTime, new DateTime()).getSeconds() < timeoutSeconds) {
        try {
          Thread.sleep(POLL_INTERVAL_MILLIS);
        } catch (InterruptedException e) {
          // Preserve the interrupt for the caller and stop waiting.
          Thread.currentThread().interrupt();
          break;
        }
      }
    } finally {
      subscriber.stopAsync();
      subscriber.awaitTerminated();
    }

    String signal = result.get();
    if (signal == null) {
      throw new AssertionError(
          String.format(
              "Did not receive signal on %s in %ss",
              signalSubscriptionPath, timeoutDuration.getStandardSeconds()));
    }
    return signal;
  }

  /** {@link PTransform} that signals once when the pipeline has started. */
  static class PublishStart extends PTransform<PBegin, PDone> {
    private final TopicPath startTopicPath;

    PublishStart(TopicPath startTopicPath) {
      this.startTopicPath = startTopicPath;
    }

    @Override
    public PDone expand(PBegin input) {
      return input
          .apply("Start signal", Create.of(START_SIGNAL_MESSAGE))
          .apply(PubsubIO.writeStrings().to(startTopicPath.getPath()));
    }
  }

  /** {@link PTransform} that validates whether elements seen so far match success criteria. */
  static class PublishSuccessWhen<T> extends PTransform<PCollection<? extends T>, POutput> {
    private final Coder<T> coder;
    private final SerializableFunction<T, String> formatter;
    private final SerializableFunction<Set<T>, Boolean> successPredicate;
    private final TopicPath resultTopicPath;

    PublishSuccessWhen(
        Coder<T> coder,
        SerializableFunction<T, String> formatter,
        SerializableFunction<Set<T>, Boolean> successPredicate,
        TopicPath resultTopicPath) {
      this.coder = coder;
      this.formatter = formatter;
      this.successPredicate = successPredicate;
      this.resultTopicPath = resultTopicPath;
    }

    @Override
    public POutput expand(PCollection<? extends T> input) {
      return input
          // assign a dummy key and global window,
          // this is needed to accumulate all observed events in the same state cell
          .apply(Window.into(new GlobalWindows()))
          .apply(WithKeys.of("dummyKey"))
          .apply(
              "checkAllEventsForSuccess",
              ParDo.of(new StatefulPredicateCheck<>(coder, formatter, successPredicate)))
          // signal the success/failure to the result topic
          .apply("publishSuccess", PubsubIO.writeStrings().to(resultTopicPath.getPath()));
    }
  }

  /**
   * Stateful {@link DoFn} which caches the elements it sees and checks whether they satisfy the
   * predicate.
   *
   * <p>Each time the predicate is satisfied it outputs "SUCCESS". If the predicate throws, it
   * outputs "FAILURE: " followed by the exception message.
   */
  static class StatefulPredicateCheck<T> extends DoFn<KV<String, ? extends T>, String> {
    private final SerializableFunction<T, String> formatter;
    private final SerializableFunction<Set<T>, Boolean> successPredicate;
    // keep all events seen so far in the state cell

    private static final String SEEN_EVENTS = "seenEvents";

    @StateId(SEEN_EVENTS)
    private final StateSpec<BagState<T>> seenEvents;

    StatefulPredicateCheck(
        Coder<T> coder,
        SerializableFunction<T, String> formatter,
        SerializableFunction<Set<T>, Boolean> successPredicate) {
      this.seenEvents = StateSpecs.bag(coder);
      this.formatter = formatter;
      this.successPredicate = successPredicate;
    }

    @ProcessElement
    public void processElement(
        ProcessContext context, @StateId(SEEN_EVENTS) BagState<T> seenEvents) {
      seenEvents.add(context.element().getValue());
      ImmutableSet<T> eventsSoFar = ImmutableSet.copyOf(seenEvents.read());

      // check if all elements seen so far satisfy the success predicate
      try {
        if (successPredicate.apply(eventsSoFar)) {
          context.output(RESULT_SUCCESS_MESSAGE);
        }
      } catch (Throwable e) {
        context.output("FAILURE: " + e.getMessage());
      }
    }
  }
}