blob: e9a31877115b30fc96a3a5c89d02e3806ce2ecb0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.aws2.sns;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.http.HttpStatus;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.sns.SnsClient;
import software.amazon.awssdk.services.sns.model.GetTopicAttributesRequest;
import software.amazon.awssdk.services.sns.model.GetTopicAttributesResponse;
import software.amazon.awssdk.services.sns.model.InternalErrorException;
import software.amazon.awssdk.services.sns.model.PublishRequest;
import software.amazon.awssdk.services.sns.model.PublishResponse;
import software.amazon.awssdk.services.sns.model.SnsException;
/**
 * {@link PTransform}s for writing to <a href="https://aws.amazon.com/sns/">SNS</a>.
 *
 * <h3>Writing to SNS</h3>
 *
 * <p>Example usage:
 *
 * <pre>{@code
 * PCollection<String> data = ...;
 *
 * data.apply(SnsIO.<String>write()
 *     .withPublishRequestFn(m -> PublishRequest.builder().topicArn("topicArn").message(m).build())
 *     .withTopicArn("topicArn")
 *     .withRetryConfiguration(
 *        SnsIO.RetryConfiguration.create(
 *          4, org.joda.time.Duration.standardSeconds(10)))
 *     .withSnsClientProvider(awsCredentialsProvider, region));
 * }</pre>
 *
 * <p>As a client, you need to provide at least the following things:
 *
 * <ul>
 *   <li>the ARN of the SNS topic you're going to publish to
 *   <li>an {@link AwsCredentialsProvider} and region (or a custom {@link SnsClientProvider})
 *   <li>publishRequestFn, a function to convert your message into a {@link PublishRequest}
 * </ul>
 *
 * <p>A {@link RetryConfiguration} is optional; without one, a failed publish is not retried.
 */
@Experimental(Experimental.Kind.SOURCE_SINK)
public final class SnsIO {

  /**
   * Returns a new, unconfigured {@link Write} transform that publishes every input element to SNS.
   * A topic ARN, a publish request function and an SNS client provider must be configured before
   * the transform is applied.
   *
   * @param <T> type of the input elements
   */
  public static <T> Write<T> write() {
    return new AutoValue_SnsIO_Write.Builder<T>().build();
  }

  /**
   * A POJO encapsulating a configuration for retry behavior when issuing requests to SNS. A retry
   * will be attempted until the maxAttempts or maxDuration is exceeded, whichever comes first, for
   * any of the following exceptions:
   *
   * <ul>
   *   <li>{@link IOException}
   *   <li>{@link InternalErrorException}
   *   <li>any other {@link SnsException} whose HTTP status code is 503 (Service Unavailable)
   * </ul>
   */
  @AutoValue
  public abstract static class RetryConfiguration implements Serializable {
    @VisibleForTesting
    static final RetryPredicate DEFAULT_RETRY_PREDICATE = new DefaultRetryPredicate();

    /** Total number of publish attempts per element, including the first one. */
    abstract int getMaxAttempts();

    /** Upper bound on the cumulative backoff time spent between attempts. */
    abstract Duration getMaxDuration();

    /** Decides which failures are eligible for a retry. */
    abstract RetryPredicate getRetryPredicate();

    abstract Builder builder();

    /**
     * Creates a retry configuration that uses the default retry predicate.
     *
     * @param maxAttempts total number of publish attempts per element, including the first one;
     *     must be greater than 0
     * @param maxDuration upper bound on the cumulative backoff between attempts; must be non-null
     *     and strictly positive
     * @return the retry configuration
     * @throws IllegalArgumentException if {@code maxAttempts} or {@code maxDuration} is invalid
     */
    public static RetryConfiguration create(int maxAttempts, Duration maxDuration) {
      checkArgument(maxAttempts > 0, "maxAttempts should be greater than 0");
      checkArgument(
          maxDuration != null && maxDuration.isLongerThan(Duration.ZERO),
          "maxDuration should be greater than 0");
      return new AutoValue_SnsIO_RetryConfiguration.Builder()
          .setMaxAttempts(maxAttempts)
          .setMaxDuration(maxDuration)
          .setRetryPredicate(DEFAULT_RETRY_PREDICATE)
          .build();
    }

    @AutoValue.Builder
    abstract static class Builder {
      abstract Builder setMaxAttempts(int maxAttempts);

      abstract Builder setMaxDuration(Duration maxDuration);

      abstract Builder setRetryPredicate(RetryPredicate retryPredicate);

      abstract RetryConfiguration build();
    }

    /**
     * An interface used to control whether the SNS Publish call is retried when a {@link Throwable}
     * occurs. If {@link RetryPredicate#test(Object)} returns true, {@link Write} re-sends the
     * publish request to SNS, provided the {@link RetryConfiguration} still permits it.
     */
    @FunctionalInterface
    interface RetryPredicate extends Predicate<Throwable>, Serializable {}

    /**
     * Retries I/O errors, SNS internal errors, and SNS service errors carrying a transient HTTP
     * status code.
     */
    private static class DefaultRetryPredicate implements RetryPredicate {
      /** HTTP status codes of SNS service errors that are considered transient. */
      private static final ImmutableSet<Integer> ELIGIBLE_CODES =
          ImmutableSet.of(HttpStatus.SC_SERVICE_UNAVAILABLE);

      @Override
      public boolean test(Throwable throwable) {
        if (throwable instanceof IOException || throwable instanceof InternalErrorException) {
          return true;
        }
        // Any other SNS service error is retried only when its status code marks it as transient.
        return throwable instanceof SnsException
            && ELIGIBLE_CODES.contains(((SnsException) throwable).statusCode());
      }
    }
  }

  /**
   * Implementation of {@link #write}: publishes every input element to SNS and emits the
   * corresponding {@link PublishResponse}.
   *
   * @param <T> type of the input elements
   */
  @AutoValue
  public abstract static class Write<T>
      extends PTransform<PCollection<T>, PCollection<PublishResponse>> {

    @Nullable
    abstract String getTopicArn();

    @Nullable
    abstract SerializableFunction<T, PublishRequest> getPublishRequestFn();

    @Nullable
    abstract SnsClientProvider getSnsClientProvider();

    @Nullable
    abstract RetryConfiguration getRetryConfiguration();

    abstract Builder<T> builder();

    @AutoValue.Builder
    abstract static class Builder<T> {
      abstract Builder<T> setTopicArn(String topicArn);

      abstract Builder<T> setPublishRequestFn(
          SerializableFunction<T, PublishRequest> publishRequestFn);

      abstract Builder<T> setSnsClientProvider(SnsClientProvider snsClientProvider);

      abstract Builder<T> setRetryConfiguration(RetryConfiguration retryConfiguration);

      abstract Write<T> build();
    }

    /**
     * Specifies the SNS topic that will be written to. This setting is mandatory.
     *
     * <p>When the transform is expanded, the topic is looked up and expansion fails if it cannot
     * be found. Note that the {@link PublishRequest}s produced by the publish request function are
     * published unchanged, so they must name their target topic themselves.
     *
     * @param topicArn the ARN of the SNS topic
     */
    public Write<T> withTopicArn(String topicArn) {
      return builder().setTopicArn(topicArn).build();
    }

    /**
     * Specifies a function for converting an input element into a {@link PublishRequest}. This
     * function is mandatory.
     *
     * @param publishRequestFn function mapping each element to the request that is published
     */
    public Write<T> withPublishRequestFn(SerializableFunction<T, PublishRequest> publishRequestFn) {
      return builder().setPublishRequestFn(publishRequestFn).build();
    }

    /**
     * Allows to specify a custom {@link SnsClientProvider}. The {@link SnsClientProvider} creates
     * new {@link SnsClient}s which are later used for writing to an SNS topic. Clients obtained
     * from the provider are closed by this transform once it no longer needs them.
     *
     * @param snsClientProvider provider of the SNS clients
     */
    public Write<T> withSnsClientProvider(SnsClientProvider snsClientProvider) {
      return builder().setSnsClientProvider(snsClientProvider).build();
    }

    /**
     * Specifies the {@link AwsCredentialsProvider} and region to be used to write to SNS. If you
     * need a more sophisticated credential protocol, then you should look at {@link
     * Write#withSnsClientProvider(SnsClientProvider)}.
     *
     * @param credentialsProvider AWS credentials used by the SNS client
     * @param region AWS region of the SNS topic
     */
    public Write<T> withSnsClientProvider(
        AwsCredentialsProvider credentialsProvider, String region) {
      return withSnsClientProvider(credentialsProvider, region, null);
    }

    /**
     * Specifies the {@link AwsCredentialsProvider} and region to be used to write to SNS. If you
     * need a more sophisticated credential protocol, then you should look at {@link
     * Write#withSnsClientProvider(SnsClientProvider)}.
     *
     * <p>The {@code serviceEndpoint} sets an alternative service host. This is useful to execute
     * the tests with an SNS service emulator.
     *
     * @param credentialsProvider AWS credentials used by the SNS client
     * @param region AWS region of the SNS topic
     * @param serviceEndpoint alternative service endpoint; may be {@code null}, as the
     *     two-argument overload passes
     */
    public Write<T> withSnsClientProvider(
        AwsCredentialsProvider credentialsProvider,
        String region,
        @Nullable URI serviceEndpoint) {
      return withSnsClientProvider(
          new BasicSnsClientProvider(credentialsProvider, region, serviceEndpoint));
    }

    /**
     * Provides configuration to retry a failed request to publish a message to SNS. Users should
     * consider that retrying might compound the underlying problem which caused the initial
     * failure. Users should also be aware that once retrying is exhausted the error is surfaced to
     * the runner which <em>may</em> then opt to retry the current bundle in its entirety or abort
     * if the max number of retries of the runner is completed. Retrying uses an exponential backoff
     * algorithm, with an initial backoff of 5 seconds, and surfaces the error once the maximum
     * number of attempts or the maximum configured duration is exceeded.
     *
     * <p>Example use:
     *
     * <pre>{@code
     * SnsIO.write()
     *   .withRetryConfiguration(SnsIO.RetryConfiguration.create(5, Duration.standardMinutes(1))
     *   ...
     * }</pre>
     *
     * @param retryConfiguration the rules which govern the retry behavior
     * @return the {@link Write} with retrying configured
     * @throws IllegalArgumentException if {@code retryConfiguration} is null
     */
    public Write<T> withRetryConfiguration(RetryConfiguration retryConfiguration) {
      checkArgument(retryConfiguration != null, "retryConfiguration is required");
      return builder().setRetryConfiguration(retryConfiguration).build();
    }

    /**
     * Returns true when SNS answers a GetTopicAttributes request for {@code topicArn} with HTTP
     * status 200. Any exception thrown by the client propagates unchanged.
     */
    private static boolean isTopicExists(SnsClient client, String topicArn) {
      GetTopicAttributesRequest attributesRequest =
          GetTopicAttributesRequest.builder().topicArn(topicArn).build();
      GetTopicAttributesResponse attributesResponse = client.getTopicAttributes(attributesRequest);
      return attributesResponse != null
          && attributesResponse.sdkHttpResponse().statusCode() == 200;
    }

    @Override
    public PCollection<PublishResponse> expand(PCollection<T> input) {
      checkArgument(getTopicArn() != null, "withTopicArn() is required");
      checkArgument(getPublishRequestFn() != null, "withPublishRequestFn() is required");
      checkArgument(getSnsClientProvider() != null, "withSnsClientProvider() is required");
      // This client is only needed for the construction-time topic check, so it is released
      // right away (SnsWriterFn likewise closes the clients it obtains in tearDown).
      try (SnsClient client = getSnsClientProvider().getSnsClient()) {
        checkArgument(
            isTopicExists(client, getTopicArn()), "Topic arn %s does not exist", getTopicArn());
      }
      return input.apply(ParDo.of(new SnsWriterFn<>(this)));
    }

    /** Publishes each element to SNS, retrying failed publishes per the retry configuration. */
    static class SnsWriterFn<T> extends DoFn<T, PublishResponse> {
      @VisibleForTesting
      static final String RETRY_ATTEMPT_LOG = "Error writing to SNS. Retry attempt[%d]";

      private static final Duration RETRY_INITIAL_BACKOFF = Duration.standardSeconds(5);
      private static final Logger LOG = LoggerFactory.getLogger(SnsWriterFn.class);
      private static final Counter SNS_WRITE_FAILURES =
          Metrics.counter(SnsWriterFn.class, "SNS_Write_Failures");

      private final Write<T> spec;
      private transient FluentBackoff retryBackoff; // defaults to no retries
      private transient SnsClient producer;

      SnsWriterFn(Write<T> spec) {
        this.spec = spec;
      }

      @Setup
      public void setup() throws Exception {
        // Initialize the SNS publisher.
        producer = spec.getSnsClientProvider().getSnsClient();
        retryBackoff =
            FluentBackoff.DEFAULT
                .withMaxRetries(0) // default to no retrying
                .withInitialBackoff(RETRY_INITIAL_BACKOFF);
        RetryConfiguration retryConfiguration = spec.getRetryConfiguration();
        if (retryConfiguration != null) {
          // maxAttempts counts the first attempt, so it allows maxAttempts - 1 retries.
          retryBackoff =
              retryBackoff
                  .withMaxRetries(retryConfiguration.getMaxAttempts() - 1)
                  .withMaxCumulativeBackoff(retryConfiguration.getMaxDuration());
        }
      }

      @ProcessElement
      public void processElement(ProcessContext context) throws Exception {
        PublishRequest request = spec.getPublishRequestFn().apply(context.element());
        // Output happens outside the retry loop so downstream failures never trigger a re-publish.
        context.output(publishWithRetries(request));
      }

      /**
       * Publishes {@code request}, retrying failures accepted by the retry predicate until the
       * backoff policy is exhausted.
       *
       * @throws IOException if the publish fails and is not (or no longer) retried
       * @throws InterruptedException if interrupted while backing off
       */
      private PublishResponse publishWithRetries(PublishRequest request)
          throws IOException, InterruptedException {
        RetryConfiguration retryConfiguration = spec.getRetryConfiguration();
        Sleeper sleeper = Sleeper.DEFAULT;
        BackOff backoff = retryBackoff.backoff();
        int attempt = 0;
        while (true) {
          attempt++;
          try {
            return producer.publish(request);
          } catch (Exception ex) {
            // Fail right away if there is no retry configuration or the error is not retryable.
            if (retryConfiguration == null || !retryConfiguration.getRetryPredicate().test(ex)) {
              SNS_WRITE_FAILURES.inc();
              LOG.info("Unable to publish message {} due to {} ", request.message(), ex);
              throw new IOException("Error writing to SNS (no attempt made to retry)", ex);
            }
            if (!BackOffUtils.next(sleeper, backoff)) {
              // Retries are exhausted: this is a failed write as well.
              SNS_WRITE_FAILURES.inc();
              throw new IOException(
                  String.format(
                      "Error writing to SNS after %d attempt(s). No more attempts allowed",
                      attempt),
                  ex);
            }
            // Note: this is used in test cases to verify behavior.
            LOG.warn(String.format(RETRY_ATTEMPT_LOG, attempt), ex);
          }
        }
      }

      @Teardown
      public void tearDown() {
        if (producer != null) {
          producer.close();
          producer = null;
        }
      }
    }
  }
}