blob: bcf4499516edb10847ab9a151468cc2632484265 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. The ASF licenses this file to You
* under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. For additional information regarding
* copyright in this work, please see the NOTICE file in the top level
* directory of this distribution.
*/
package org.apache.usergrid.persistence.queue.impl;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sns.AmazonSNSClient;
import com.amazonaws.services.sns.model.*;
import com.amazonaws.services.sqs.AmazonSQSClient;
import com.amazonaws.services.sqs.model.*;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.inject.Inject;
import com.google.inject.assistedinject.Assisted;
import org.apache.usergrid.persistence.core.astyanax.CassandraFig;
import org.apache.usergrid.persistence.core.guicyfig.ClusterFig;
import org.apache.usergrid.persistence.queue.*;
import org.apache.usergrid.persistence.queue.Queue;
import org.apache.usergrid.persistence.queue.util.AmazonNotificationUtils;
import org.apache.usergrid.persistence.core.rx.RxTaskScheduler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutionException;
import rx.Observable;
/**
 * {@link QueueManager} backed by Amazon SNS for publishing and Amazon SQS for consuming.
 *
 * <p>Messages are published to the SNS topic returned by {@link #getWriteTopicArn()} and read from the
 * SQS queue returned by {@link #getReadQueue()}. Both are named by {@code getName()}, and
 * {@code setupTopics(String)} subscribes the queue to the topic. When multi-region support is enabled
 * and the scope's region implementation is {@code ALL}, the queue in every configured region is
 * subscribed to the topic in every configured region.
 */
public class SNSQueueManagerImpl implements QueueManager {

    private static final Logger logger = LoggerFactory.getLogger(SNSQueueManagerImpl.class);

    /** Largest batch SQS accepts for both ReceiveMessage and DeleteMessageBatch requests. */
    private static final int MAX_SQS_BATCH_SIZE = 10;

    private final QueueScope scope;
    private final QueueFig fig;
    private final ClusterFig clusterFig;
    private final CassandraFig cassandraFig;
    private final AmazonSQSClient sqs;
    private final AmazonSNSClient sns;
    private final RxTaskScheduler rxTaskScheduler;

    private final JsonFactory JSON_FACTORY = new JsonFactory();
    private final ObjectMapper mapper = new ObjectMapper(JSON_FACTORY);

    /** Caches, per queue name, the primary SNS topic ARN returned by {@code setupTopics}. */
    private final LoadingCache<String, String> writeTopicArnMap = CacheBuilder.newBuilder()
        .maximumSize(1000)
        .build(new CacheLoader<String, String>() {
            @Override
            public String load(String queueName) throws Exception {
                return setupTopics(queueName);
            }
        });

    /**
     * Caches, per queue name, the primary-region SQS queue. Loading looks up the queue URL, creates the
     * queue when SQS reports it does not exist, and then runs {@code setupTopics} for the same name.
     */
    private final LoadingCache<String, Queue> readQueueUrlMap = CacheBuilder.newBuilder()
        .maximumSize(1000)
        .build(new CacheLoader<String, Queue>() {
            @Override
            public Queue load(String queueName) throws Exception {
                Queue queue = null;
                try {
                    GetQueueUrlResult result = sqs.getQueueUrl(queueName);
                    queue = new Queue(result.getQueueUrl());
                } catch (QueueDoesNotExistException queueDoesNotExistException) {
                    logger.error("Queue {} does not exist, will create", queueName);
                } catch (Exception e) {
                    logger.error("failed to get queue from service", e);
                    throw e;
                }
                if (queue == null) {
                    String url = AmazonNotificationUtils.createQueue(sqs, queueName, fig);
                    queue = new Queue(url);
                }
                setupTopics(queueName);
                return queue;
            }
        });

    /**
     * Creates the manager and its primary-region SQS and SNS clients.
     *
     * @throws RuntimeException if the configured region cannot be resolved or a client cannot be built
     */
    @Inject
    public SNSQueueManagerImpl(@Assisted QueueScope scope, QueueFig fig, ClusterFig clusterFig,
                               CassandraFig cassandraFig, final RxTaskScheduler rxTaskScheduler) {
        this.scope = scope;
        this.fig = fig;
        this.clusterFig = clusterFig;
        this.cassandraFig = cassandraFig;
        this.rxTaskScheduler = rxTaskScheduler;
        try {
            final Region region = getRegion();
            sqs = createSQSClient(region);
            sns = createSNSClient(region);
        } catch (Exception e) {
            throw new RuntimeException("Error setting up SQS/SNS clients", e);
        }
    }

    /**
     * Ensures the SNS topic and SQS queue named {@code queueName} exist in the primary region. Subscribes
     * the queue to the topic and grants the topic permission to send to the queue. For a multi-region
     * {@code ALL} scope, also wires every configured region (see {@link #setupMultiRegion}).
     *
     * @param queueName name used for both the topic and the queue
     * @return ARN of the primary-region SNS topic
     * @throws Exception if a topic/queue lookup or creation fails; primary subscription failures
     *                   ({@link AmazonServiceException}) are logged rather than thrown
     */
    private String setupTopics(final String queueName) throws Exception {
        logger.info("Setting up setupTopics SNS/SQS...");

        String primaryTopicArn = AmazonNotificationUtils.getTopicArn(sns, queueName, true);
        if (logger.isDebugEnabled()) {
            logger.debug("SNS/SQS Setup: primaryTopicArn=" + primaryTopicArn);
        }

        String queueUrl = AmazonNotificationUtils.getQueueUrlByName(sqs, queueName);
        String primaryQueueArn = AmazonNotificationUtils.getQueueArnByName(sqs, queueName);
        if (logger.isDebugEnabled()) {
            logger.debug("SNS/SQS Setup: primaryQueueArn=" + primaryQueueArn);
        }

        if (primaryQueueArn == null) {
            logger.debug("SNS/SQS Setup: primaryQueueArn is null, creating queue...");
            queueUrl = AmazonNotificationUtils.createQueue(sqs, queueName, fig);
            primaryQueueArn = AmazonNotificationUtils.getQueueArnByUrl(sqs, queueUrl);
            logger.debug("SNS/SQS Setup: New Queue URL=[{}] ARN=[{}]", queueUrl, primaryQueueArn);
        }

        try {
            SubscribeRequest primarySubscribeRequest = new SubscribeRequest(primaryTopicArn, "sqs", primaryQueueArn);
            sns.subscribe(primarySubscribeRequest);

            // ensure the SNS primary topic has permission to send to the primary SQS queue
            List<String> primaryTopicArnList = new ArrayList<>();
            primaryTopicArnList.add(primaryTopicArn);
            AmazonNotificationUtils.setQueuePermissionsToReceive(sqs, queueUrl, primaryTopicArnList);
        } catch (AmazonServiceException e) {
            logger.error(String.format("Unable to subscribe PRIMARY queue=[%s] to topic=[%s]", queueUrl, primaryTopicArn), e);
        }

        if (fig.isMultiRegion() && scope.getRegionImplementation() == QueueScope.RegionImplementation.ALL) {
            setupMultiRegion(queueName, primaryTopicArn, primaryQueueArn);
        }

        return primaryTopicArn;
    }

    /**
     * Creates the topic and queue named {@code queueName} in every region listed by
     * {@code fig.getRegionList()}. Then subscribes every queue (primary included) to every topic (primary
     * included) and grants each queue permission to receive from those topics.
     *
     * <p>One SQS and one SNS client are created per region for the duration of this call. All of them are
     * shut down before returning so they do not accumulate across calls. The manager's own primary
     * clients are never touched.
     *
     * @throws Exception if topic/queue lookup or creation fails in any region; individual subscription
     *                   failures are logged rather than thrown
     */
    private void setupMultiRegion(final String queueName, final String primaryTopicArn,
                                  final String primaryQueueArn) throws Exception {
        String multiRegion = fig.getRegionList();
        logger.debug("MultiRegion Setup specified, regions: [{}]", multiRegion);

        String[] regionNames = multiRegion.split(",");

        // ARN -> region name, for queues and topics respectively
        final Map<String, String> queueArnRegions = new HashMap<>(regionNames.length + 1);
        final Map<String, String> topicArnRegions = new HashMap<>(regionNames.length + 1);
        queueArnRegions.put(primaryQueueArn, fig.getRegion());
        topicArnRegions.put(primaryTopicArn, fig.getRegion());

        final Map<String, AmazonSQSClient> sqsClients = new HashMap<>();
        final Map<String, AmazonSNSClient> snsClients = new HashMap<>();
        try {
            for (String rawRegionName : regionNames) {
                final String regionName = rawRegionName.trim();
                AmazonSQSClient sqsClient =
                    sqsClients.computeIfAbsent(regionName, name -> createSQSClient(regionFor(name)));
                AmazonSNSClient snsClient =
                    snsClients.computeIfAbsent(regionName, name -> createSNSClient(regionFor(name)));

                // getTopicArn will create the SNS topic if it doesn't exist
                String topicArn = AmazonNotificationUtils.getTopicArn(snsClient, queueName, true);
                topicArnRegions.put(topicArn, regionName);

                // create the SQS queue if it doesn't exist
                String queueArn = AmazonNotificationUtils.getQueueArnByName(sqsClient, queueName);
                if (queueArn == null) {
                    String regionQueueUrl = AmazonNotificationUtils.createQueue(sqsClient, queueName, fig);
                    queueArn = AmazonNotificationUtils.getQueueArnByUrl(sqsClient, regionQueueUrl);
                }
                queueArnRegions.put(queueArn, regionName);
            }

            logger.debug("Creating Subscriptions...");

            for (Map.Entry<String, String> queueEntry : queueArnRegions.entrySet()) {
                final String queueARN = queueEntry.getKey();
                final String strSqsRegion = queueEntry.getValue();
                AmazonSQSClient subscribeSqsClient =
                    sqsClients.computeIfAbsent(strSqsRegion, name -> createSQSClient(regionFor(name)));

                // ensure the URL used to subscribe is for the correct name/region
                String subscribeQueueUrl = AmazonNotificationUtils.getQueueUrlByName(subscribeSqsClient, queueName);

                // topics this queue gets subscribed to; used afterwards to grant receive permission
                List<String> topicArnList = new ArrayList<>();

                for (Map.Entry<String, String> topicEntry : topicArnRegions.entrySet()) {
                    final String topicARN = topicEntry.getKey();
                    final String strSnsRegion = topicEntry.getValue();
                    topicArnList.add(topicARN);

                    AmazonSNSClient subscribeSnsClient =
                        snsClients.computeIfAbsent(strSnsRegion, name -> createSNSClient(regionFor(name)));
                    SubscribeRequest subscribeRequest = new SubscribeRequest(topicARN, "sqs", queueARN);

                    try {
                        logger.info("Subscribing Queue ARN/Region=[{} / {}] and Topic ARN/Region=[{} / {}]",
                            queueARN,
                            strSqsRegion,
                            topicARN,
                            strSnsRegion
                        );

                        SubscribeResult subscribeResult = subscribeSnsClient.subscribe(subscribeRequest);
                        String subscriptionARN = subscribeResult.getSubscriptionArn();
                        if (logger.isDebugEnabled()) {
                            logger.debug("Successfully subscribed Queue ARN=[{}] to Topic ARN=[{}], subscription ARN=[{}]",
                                queueARN, topicARN, subscriptionARN);
                        }
                    } catch (Exception e) {
                        logger.error(String.format("ERROR Subscribing Queue ARN/Region=[%s / %s] and Topic ARN/Region=[%s / %s]",
                            queueARN,
                            strSqsRegion,
                            topicARN,
                            strSnsRegion), e);
                    }
                }

                logger.info("Adding permission to receive messages...");
                // add permission to each queue, providing a list of topics that it's subscribed to
                AmazonNotificationUtils.setQueuePermissionsToReceive(subscribeSqsClient, subscribeQueueUrl, topicArnList);
            }
        } finally {
            sqsClients.values().forEach(AmazonSQSClient::shutdown);
            snsClients.values().forEach(AmazonSNSClient::shutdown);
        }
    }

    /**
     * Creates the synchronous SNS client used for creating topics, subscribing queues and publishing.
     *
     * @param region region the client talks to
     * @return a new client authenticated with {@link UsergridAwsCredentialsProvider} credentials
     */
    private AmazonSNSClient createSNSClient(final Region region) {
        final UsergridAwsCredentialsProvider ugProvider = new UsergridAwsCredentialsProvider();
        final AmazonSNSClient snsClient = new AmazonSNSClient(ugProvider.getCredentials());
        snsClient.setRegion(region);
        return snsClient;
    }

    /**
     * Builds the topic/queue name from the cluster name, application keyspace, scope name and region
     * implementation, lower-cased.
     *
     * @throws IllegalArgumentException if the resulting name is longer than 80 characters
     */
    private String getName() {
        String name = clusterFig.getClusterName() + "_" + cassandraFig.getApplicationKeyspace() + "_"
            + scope.getName() + "_" + scope.getRegionImplementation();
        // Locale.ROOT avoids locale-sensitive mappings (e.g. Turkish dotless i) in AWS resource names
        name = name.toLowerCase(Locale.ROOT);
        Preconditions.checkArgument(name.length() <= 80, "Your name must be <= 80 characters");
        return name;
    }

    /**
     * Returns the primary-region SQS queue for this scope, creating and wiring it on first use.
     *
     * @throws RuntimeException wrapping any failure from the cache loader
     */
    public Queue getReadQueue() {
        String queueName = getName();
        try {
            return readQueueUrlMap.get(queueName);
        } catch (ExecutionException ee) {
            throw new RuntimeException(ee);
        }
    }

    /**
     * Returns the primary-region SNS topic ARN that messages are published to, set up on first use.
     *
     * @throws RuntimeException wrapping any failure from the cache loader
     */
    public String getWriteTopicArn() {
        try {
            return writeTopicArnMap.get(getName());
        } catch (ExecutionException ee) {
            throw new RuntimeException(ee);
        }
    }

    /**
     * Receives up to {@code limit} messages (capped at the SQS maximum of 10) from the read queue.
     * Messages delivered through SNS are unwrapped from the envelope's {@code Message} field.
     *
     * @param limit              maximum number of messages to request
     * @param transactionTimeout visibility timeout, in milliseconds (sent to SQS in whole seconds)
     * @param waitTime           long-poll wait time, in milliseconds (sent to SQS in whole seconds)
     * @param klass              type each message body is deserialized into
     * @return the received messages; empty if the client is missing or the receive fails
     */
    @Override
    public rx.Observable<QueueMessage> getMessages(final int limit,
                                                   final int transactionTimeout,
                                                   final int waitTime,
                                                   final Class klass) {
        if (sqs == null) {
            logger.error("SQS is null - was not initialized properly");
            return rx.Observable.empty();
        }

        String url = getReadQueue().getUrl();
        // SQS rejects MaxNumberOfMessages above 10; without the cap every poll would fail and return nothing
        final int maxMessages = Math.min(limit, MAX_SQS_BATCH_SIZE);

        if (logger.isDebugEnabled()) logger.debug("Getting up to {} messages from {}", maxMessages, url);

        ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(url);
        receiveMessageRequest.setMaxNumberOfMessages(maxMessages);
        receiveMessageRequest.setVisibilityTimeout(transactionTimeout / 1000);
        receiveMessageRequest.setWaitTimeSeconds(waitTime / 1000);

        try {
            ReceiveMessageResult result = sqs.receiveMessage(receiveMessageRequest);
            List<Message> messages = result.getMessages();

            if (logger.isDebugEnabled()) logger.debug("Received {} messages from {}", messages.size(), url);

            List<QueueMessage> queueMessages = new ArrayList<>(messages.size());
            for (Message message : messages) {
                Object body;
                try {
                    final JsonNode bodyNode = mapper.readTree(message.getBody());
                    final JsonNode bodyObj = bodyNode.has("Message") ? bodyNode.get("Message") : bodyNode;
                    // SNS envelopes carry the payload as a JSON string; a raw body is already the JSON tree
                    // (textValue() would return null for it)
                    body = bodyObj.isTextual()
                        ? fromString(bodyObj.textValue(), klass)
                        : mapper.treeToValue(bodyObj, klass);
                } catch (Exception e) {
                    logger.error(String.format("failed to deserialize message: %s", message.getBody()), e);
                    throw new RuntimeException(e);
                }

                QueueMessage queueMessage = new QueueMessage(message.getMessageId(), message.getReceiptHandle(),
                    body, message.getAttributes().get("type"));
                queueMessages.add(queueMessage);
            }

            return rx.Observable.from(queueMessages);
        } catch (com.amazonaws.services.sqs.model.QueueDoesNotExistException dne) {
            logger.error(String.format("Queue does not exist! [%s]", url), dne);
        } catch (Exception e) {
            logger.error(String.format("Programming error getting messages from queue=[%s] exist!", url), e);
        }

        return rx.Observable.empty();
    }

    /**
     * Returns SQS's {@code ApproximateNumberOfMessages} for the read queue, 0 if the attribute is
     * absent, or -1 if the lookup fails.
     */
    @Override
    public long getQueueDepth() {
        String key = "ApproximateNumberOfMessages";
        try {
            GetQueueAttributesResult result = sqs.getQueueAttributes(getReadQueue().getUrl(), Collections.singletonList(key));
            String depthString = result.getAttributes().get(key);
            return depthString != null ? Long.parseLong(depthString) : 0;
        } catch (Exception e) {
            logger.error("Exception getting queue depth", e);
            return -1;
        }
    }

    /** Publishes each element of {@code bodies} via {@link #sendMessage(Object)}. */
    @Override
    public void sendMessages(final List bodies) throws IOException {
        if (sns == null) {
            logger.error("SNS client is null, perhaps it failed to initialize successfully");
            return;
        }

        for (Object body : bodies) {
            sendMessage(body);
        }
    }

    /**
     * Serializes {@code body} to JSON and publishes it to the write topic asynchronously on the RX
     * async-IO scheduler. Failures are logged, not thrown.
     */
    @Override
    public void sendMessage(final Object body) throws IOException {
        Observable.just(body).doOnNext(message -> {
            if (sns == null) {
                logger.error("SNS client is null, perhaps it failed to initialize successfully");
                return;
            }

            try {
                final String stringBody = toString(message);
                String topicArn = getWriteTopicArn();

                if (logger.isDebugEnabled()) {
                    logger.debug("Publishing Message...{} to arn: {}", stringBody, topicArn);
                }

                PublishRequest publishRequest = new PublishRequest(topicArn, stringBody);

                // publish message to SNS
                PublishResult publishResult = sns.publish(publishRequest);

                if (logger.isDebugEnabled()) {
                    logger.debug("Successfully published... messageID=[{}], arn=[{}]",
                        publishResult.getMessageId(), publishRequest.getTopicArn());
                }
            } catch (IOException e) {
                logger.error("Unable to convert queue object to a string message body", e);
            }
        }).subscribeOn(rxTaskScheduler.getAsyncIOScheduler())
            // supply onError: a bare subscribe() turns any failure into OnErrorNotImplementedException
            .subscribe(ignored -> { }, e -> logger.error("Error while publishing SNS message: ", e));
    }

    /** Deletes a single received message from the read queue using its receipt handle. */
    @Override
    public void commitMessage(final QueueMessage queueMessage) {
        String url = getReadQueue().getUrl();
        if (logger.isDebugEnabled()) {
            logger.debug("Commit message {} to queue {}", queueMessage.getMessageId(), url);
        }

        sqs.deleteMessage(new DeleteMessageRequest()
            .withQueueUrl(url)
            .withReceiptHandle(queueMessage.getHandle()));
    }

    /**
     * Deletes the given received messages from the read queue. SQS rejects empty batches and batches of
     * more than 10 entries, so empty input is a no-op and larger lists are sent in chunks of 10.
     * Per-entry failures are logged.
     */
    @Override
    public void commitMessages(final List<QueueMessage> queueMessages) {
        if (queueMessages == null || queueMessages.isEmpty()) {
            return;
        }

        String url = getReadQueue().getUrl();
        if (logger.isDebugEnabled()) logger.debug("Commit messages {} to queue {}", queueMessages.size(), url);

        for (int start = 0; start < queueMessages.size(); start += MAX_SQS_BATCH_SIZE) {
            final List<QueueMessage> chunk =
                queueMessages.subList(start, Math.min(start + MAX_SQS_BATCH_SIZE, queueMessages.size()));

            List<DeleteMessageBatchRequestEntry> entries = new ArrayList<>(chunk.size());
            for (QueueMessage message : chunk) {
                entries.add(new DeleteMessageBatchRequestEntry(message.getMessageId(), message.getHandle()));
            }

            DeleteMessageBatchResult result = sqs.deleteMessageBatch(new DeleteMessageBatchRequest(url, entries));
            for (BatchResultErrorEntry failed : result.getFailed()) {
                logger.error("Commit failed reason: {} messages id: {}", failed.getMessage(), failed.getId());
            }
        }
    }

    /**
     * Deserializes a JSON string into an instance of {@code klass}.
     */
    private Object fromString(final String s, final Class klass) throws IOException {
        return mapper.readValue(s, klass);
    }

    /**
     * Serializes an object to a JSON string.
     */
    private String toString(final Object o) throws IOException {
        return mapper.writeValueAsString(o);
    }

    /**
     * Resolves the primary region configured by {@code fig.getRegion()}.
     *
     * @return the primary AWS region
     */
    private Region getRegion() {
        return regionFor(fig.getRegion());
    }

    /** Resolves an AWS {@link Region} from its name (e.g. {@code us-east-1}). */
    private static Region regionFor(final String regionName) {
        return Region.getRegion(Regions.fromName(regionName));
    }

    /**
     * Creates the SQS client for the given region, authenticated with
     * {@link UsergridAwsCredentialsProvider} credentials.
     */
    private AmazonSQSClient createSQSClient(final Region region) {
        final UsergridAwsCredentialsProvider ugProvider = new UsergridAwsCredentialsProvider();
        final AmazonSQSClient sqsClient = new AmazonSQSClient(ugProvider.getCredentials());
        sqsClient.setRegion(region);
        return sqsClient;
    }
}