blob: 6cda0d54250cfad8c7b6c260b66b277a311edc06 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.geode.benchmark.redis.tasks;
import static org.assertj.core.api.Assertions.assertThat;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.geode.benchmark.redis.tests.PubSubBenchmarkConfiguration;
import org.apache.geode.perftest.Task;
import org.apache.geode.perftest.TestContext;
/**
 * Benchmark task that subscribes each configured Redis client to the pub/sub channels (or channel
 * patterns) described by a {@link PubSubBenchmarkConfiguration}.
 *
 * <p>Each subscription call is submitted to its own thread of a fixed-size pool. The subscribers
 * and the pool are stored in the {@link TestContext} so that {@link #shutdown(TestContext)} can
 * later wait for every subscription to finish and release the pool.
 */
public class SubscribeRedisTask implements Task {
  private static final Logger logger = LoggerFactory.getLogger(SubscribeRedisTask.class);

  // TestContext keys for shared objects between the SubscribeTask (before) and
  // the PubSubEndTask (after)
  private static final String SUBSCRIBERS_CONTEXT_KEY = "subscribers";
  private static final String SUBSCRIBERS_THREAD_POOL = "threadPool";

  private final List<RedisClientManager> subscriberClientManagers;
  private final boolean validate;
  private final PubSubBenchmarkConfiguration pubSubConfig;

  /**
   * Creates the task.
   *
   * @param pubSubConfig channel, message and barrier configuration shared with the publishers
   * @param subscriberClientManagers one manager per subscriber; each supplies the client used to
   *        subscribe
   * @param validate if true, every received message is reported via
   *        {@link TestContext#logProgress} and its length is asserted to equal the configured
   *        message length
   */
  public SubscribeRedisTask(final PubSubBenchmarkConfiguration pubSubConfig,
      final List<RedisClientManager> subscriberClientManagers,
      final boolean validate) {
    this.pubSubConfig = pubSubConfig;
    this.subscriberClientManagers = subscriberClientManagers;
    this.validate = validate;
    logger.info(
        "Initialized: SubscribeRedisTask numChannels={}, numMessagesPerChannel={}, messageLength={}, validate={}, useChannelPattern={}",
        pubSubConfig.getNumChannels(), pubSubConfig.getNumMessagesPerChannelOperation(),
        pubSubConfig.getMessageLength(), validate, pubSubConfig.shouldUseChannelPattern());
  }

  /**
   * Creates one {@link Subscriber} per client manager and starts each subscription on a dedicated
   * pool thread. Both the subscribers and the pool are saved in {@code context} for
   * {@link #shutdown(TestContext)}.
   *
   * @param context the per-worker test context
   * @throws Exception if interrupted while sleeping after starting the subscriptions
   */
  @Override
  public void run(final TestContext context) throws Exception {
    final CyclicBarrier barrier = pubSubConfig.getCyclicBarrier();

    // save subscribers in the TestContext, as this will be shared with
    // the after tasks which will call shutdown()
    final List<Subscriber> subscribers = subscriberClientManagers.stream()
        .map(cm -> new Subscriber(cm.get(), barrier, context))
        .collect(Collectors.toList());
    context.setAttribute(SUBSCRIBERS_CONTEXT_KEY, subscribers);

    // save thread pool in TestContext, so it can be shutdown cleanly after.
    // newFixedThreadPool rejects a size of 0, so always request at least one thread.
    final ExecutorService subscriberThreadPool =
        Executors.newFixedThreadPool(Math.max(1, subscriberClientManagers.size()));
    context.setAttribute(SUBSCRIBERS_THREAD_POOL, subscriberThreadPool);

    for (final Subscriber subscriber : subscribers) {
      subscriber.subscribeAsync(subscriberThreadPool, context);
    }

    // sleep to try to make sure subscribe is complete before continuing (heuristic only)
    Thread.sleep(1000);
  }

  /**
   * Waits for every subscriber started by {@link #run(TestContext)} on this worker to complete,
   * then shuts down the subscriber thread pool. The pool is shut down even if waiting fails.
   *
   * @param cxt the same test context previously passed to {@link #run(TestContext)}
   * @throws IllegalStateException if {@link #run(TestContext)} has not stored subscribers in
   *         {@code cxt}
   * @throws Exception if a subscriber failed, did not complete in time, or the thread was
   *         interrupted
   */
  public static void shutdown(final TestContext cxt) throws Exception {
    // precondition: method run has been previously executed in this Worker
    // and therefore subscribers and threadPool are available
    @SuppressWarnings("unchecked")
    final List<Subscriber> subscribers =
        (List<Subscriber>) cxt.getAttribute(SUBSCRIBERS_CONTEXT_KEY);
    final ExecutorService threadPool = (ExecutorService) cxt.getAttribute(SUBSCRIBERS_THREAD_POOL);
    try {
      if (subscribers == null) {
        throw new IllegalStateException(
            "No subscribers found in TestContext; SubscribeRedisTask.run() has not been executed");
      }
      for (final SubscribeRedisTask.Subscriber subscriber : subscribers) {
        subscriber.waitForCompletion(cxt);
      }
    } finally {
      // Always release the pool, even if a subscriber failed or timed out above.
      if (threadPool != null) {
        logger.info("Shutting down thread pool…");
        threadPool.shutdownNow();
        if (threadPool.awaitTermination(5, TimeUnit.MINUTES)) {
          logger.info("Thread pool terminated");
        } else {
          logger.warn("Thread pool did not terminate within 5 minutes");
        }
      }
    }
  }

  /**
   * A single pub/sub subscriber bound to one {@link RedisClient}.
   *
   * <p>Data messages are counted. When the expected number of messages
   * (channels x messages-per-channel) has been received, the counter is reset and the subscriber
   * waits on the shared {@link CyclicBarrier}. An end message on the control channel unsubscribes
   * from all subscribe channels.
   */
  public class Subscriber {
    private final AtomicInteger messagesReceived;
    private final int numMessagesExpected;
    private final RedisClient client;
    private final RedisClient.SubscriptionListener listener;
    private CompletableFuture<Void> future;

    Subscriber(final RedisClient client, final CyclicBarrier barrier, final TestContext context) {
      this.messagesReceived = new AtomicInteger(0);
      this.client = client;
      this.numMessagesExpected =
          pubSubConfig.getNumChannels() * pubSubConfig.getNumMessagesPerChannelOperation();
      this.listener = client.createSubscriptionListener(pubSubConfig,
          (String channel, String message, RedisClient.Unsubscriber unsubscriber) -> {
            if (channel.equals(pubSubConfig.getControlChannel())) {
              if (message.equals(pubSubConfig.getEndMessage())) {
                unsubscriber.unsubscribe(pubSubConfig.getAllSubscribeChannels());
                logger.info("Subscriber thread unsubscribed.");
              } else {
                throw new AssertionError("Unrecognized control message: " + message);
              }
            } else if (receiveMessageAndIsComplete(channel, message, context)) {
              reset();
              awaitBarrier(barrier);
            }
            return null;
          });
    }

    /**
     * Starts this subscriber's subscription on {@code threadPool}. Completion, including failure,
     * is logged.
     *
     * @param threadPool executor that runs the subscription call
     * @param context test context used for progress reporting
     */
    public void subscribeAsync(final ExecutorService threadPool, final TestContext context) {
      future = CompletableFuture.runAsync(() -> subscribeToConfiguredChannels(context), threadPool);
      future.whenComplete((result, ex) -> {
        logger.info("Subscriber thread completed");
        if (ex != null) {
          logger.error("Subscriber completed with exception", ex);
          context.logProgress(String.format("Subscriber completed with exception '%s'", ex));
        }
      });
    }

    /**
     * Waits up to 10 seconds for the subscription started by {@link #subscribeAsync} to finish.
     * Does nothing if the subscription was never started.
     *
     * @param ctx the test context (not used by this method)
     * @throws Exception if the subscription failed, the wait was interrupted, or it timed out
     */
    public void waitForCompletion(final TestContext ctx) throws Exception {
      if (future == null) {
        return;
      }
      assertThat(future.get(10, TimeUnit.SECONDS)).isNull();
    }

    /**
     * Issues the subscribe (or pattern-subscribe) call for all configured channels. Runs on a
     * pool thread; presumably the client call blocks until unsubscribed — confirm against the
     * RedisClient implementations.
     */
    private void subscribeToConfiguredChannels(final TestContext context) {
      final List<String> channels = pubSubConfig.getAllSubscribeChannels();
      final String[] channelArray = channels.toArray(new String[0]);
      if (pubSubConfig.shouldUseChannelPattern()) {
        context.logProgress("Subscribing to channel patterns " + channels);
        client.psubscribe(listener, channelArray);
      } else {
        context.logProgress("Subscribing to channels " + channels);
        client.subscribe(listener, channelArray);
      }
    }

    /**
     * Waits up to 10 seconds on {@code barrier} after a full round of messages.
     *
     * @throws RuntimeException if the wait times out (the timeout is kept as the cause)
     */
    private void awaitBarrier(final CyclicBarrier barrier) {
      try {
        barrier.await(10, TimeUnit.SECONDS);
      } catch (final TimeoutException e) {
        throw new RuntimeException("Subscriber timed out while waiting on barrier", e);
      } catch (final InterruptedException e) {
        // Restore the interrupt status so code further up this thread can observe it.
        Thread.currentThread().interrupt();
        logger.debug("Subscriber interrupted while waiting on barrier");
      } catch (final BrokenBarrierException e) {
        // The barrier was broken by another party (e.g. its wait timed out or was interrupted)
        // or was reset; this subscriber keeps listening, as before.
        logger.warn("Subscriber barrier was broken; continuing");
      }
    }

    /**
     * Records one received data message.
     *
     * @return true once at least the expected number of messages has been received
     */
    private boolean receiveMessageAndIsComplete(final String channel, final String message,
        final TestContext context) {
      if (validate) {
        context.logProgress(String.format(
            "Received message %s of length %d on channel %s; messagesReceived=%d; messagesExpected=%d",
            message, message.length(), channel, messagesReceived.get() + 1, numMessagesExpected));
        assertThat(message.length()).isEqualTo(pubSubConfig.getMessageLength());
      }
      return messagesReceived.incrementAndGet() >= numMessagesExpected;
    }

    /** Resets the received-message counter for the next round. */
    private void reset() {
      messagesReceived.set(0);
    }
  }
}