Add a concept for retrying messages
patch by David Capwell; reviewed by Alex Petrov for CASSANDRA-19856
diff --git a/src/java/org/apache/cassandra/db/SystemKeyspace.java b/src/java/org/apache/cassandra/db/SystemKeyspace.java
index 8709453..05a2437 100644
--- a/src/java/org/apache/cassandra/db/SystemKeyspace.java
+++ b/src/java/org/apache/cassandra/db/SystemKeyspace.java
@@ -120,6 +120,7 @@
import org.apache.cassandra.utils.MD5Digest;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.TimeUUID;
+import org.apache.cassandra.utils.TriFunction;
import org.apache.cassandra.utils.concurrent.Future;
import static java.lang.String.format;
@@ -1938,8 +1939,8 @@
int counter = 0;
for (UntypedResultSet.Row row : resultSet)
{
- if (onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
- row.getString("query_string"),
+ if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+ row.getString("query_string"),
row.has("logged_keyspace") ? row.getString("logged_keyspace") : null))
counter++;
}
@@ -1953,18 +1954,14 @@
int counter = 0;
for (UntypedResultSet.Row row : resultSet)
{
- if (onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
- row.getString("query_string"),
+ if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+ row.getString("query_string"),
row.has("logged_keyspace") ? row.getString("logged_keyspace") : null))
counter++;
}
return counter;
}
- public static interface TriFunction<A, B, C, D> {
- D accept(A var1, B var2, C var3);
- }
-
public static void saveTopPartitions(TableMetadata metadata, String topType, Collection<TopPartitionTracker.TopPartition> topPartitions, long lastUpdate)
{
String cql = String.format("INSERT INTO %s.%s (keyspace_name, table_name, top_type, top, last_update) values (?, ?, ?, ?, ?)", SchemaConstants.SYSTEM_KEYSPACE_NAME, TOP_PARTITIONS);
diff --git a/src/java/org/apache/cassandra/net/MessageDelivery.java b/src/java/org/apache/cassandra/net/MessageDelivery.java
index 0b7890c..0d052cb 100644
--- a/src/java/org/apache/cassandra/net/MessageDelivery.java
+++ b/src/java/org/apache/cassandra/net/MessageDelivery.java
@@ -19,19 +19,27 @@
package org.apache.cassandra.net;
import java.util.Collection;
+import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.RequestFailureReason;
import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.Backoff;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.concurrent.Accumulator;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
import org.apache.cassandra.utils.concurrent.CountDownLatch;
import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.Promise;
+
+import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
public interface MessageDelivery
{
@@ -74,9 +82,161 @@
public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb);
public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType specifyConnection);
public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ> message, InetAddressAndPort to);
+
+ public default <REQ, RSP> Future<Message<RSP>> sendWithRetries(Backoff backoff, RetryScheduler retryThreads,
+ Verb verb, REQ request,
+ Iterator<InetAddressAndPort> candidates,
+ RetryPredicate shouldRetry,
+ RetryErrorMessage errorMessage)
+ {
+ Promise<Message<RSP>> promise = new AsyncPromise<>();
+ this.<REQ, RSP>sendWithRetries(backoff, retryThreads, verb, request, candidates,
+ (attempt, success, failure) -> {
+ if (failure != null) promise.tryFailure(failure);
+ else promise.trySuccess(success);
+ },
+ shouldRetry, errorMessage);
+ return promise;
+ }
+
+ public default <REQ, RSP> void sendWithRetries(Backoff backoff, RetryScheduler retryThreads,
+ Verb verb, REQ request,
+ Iterator<InetAddressAndPort> candidates,
+ OnResult<RSP> onResult,
+ RetryPredicate shouldRetry,
+ RetryErrorMessage errorMessage)
+ {
+ sendWithRetries(this, backoff, retryThreads, verb, request, candidates, onResult, shouldRetry, errorMessage, 0);
+ }
public <V> void respond(V response, Message<?> message);
public default void respondWithFailure(RequestFailureReason reason, Message<?> message)
{
send(Message.failureResponse(message.id(), message.expiresAtNanos(), reason), message.respondTo());
}
+
+ interface OnResult<T>
+ {
+ void result(int attempt, @Nullable Message<T> success, @Nullable Throwable failure);
+ }
+
+ interface RetryPredicate
+ {
+ boolean test(int attempt, InetAddressAndPort from, RequestFailureReason failure);
+ }
+
+ interface RetryErrorMessage
+ {
+ String apply(int attempt, ResponseFailureReason retryFailure, @Nullable InetAddressAndPort from, @Nullable RequestFailureReason reason);
+ }
+
+ private static <REQ, RSP> void sendWithRetries(MessageDelivery messaging,
+ Backoff backoff, RetryScheduler retryThreads,
+ Verb verb, REQ request,
+ Iterator<InetAddressAndPort> candidates,
+ OnResult<RSP> onResult,
+ RetryPredicate shouldRetry,
+ RetryErrorMessage errorMessage,
+ int attempt)
+ {
+ if (Thread.currentThread().isInterrupted())
+ {
+ onResult.result(attempt, null, new InterruptedException(errorMessage.apply(attempt, ResponseFailureReason.Interrupted, null, null)));
+ return;
+ }
+ if (!candidates.hasNext())
+ {
+ onResult.result(attempt, null, new NoMoreCandidatesException(errorMessage.apply(attempt, ResponseFailureReason.NoMoreCandidates, null, null)));
+ return;
+ }
+ class Request implements RequestCallbackWithFailure<RSP>
+ {
+ @Override
+ public void onResponse(Message<RSP> msg)
+ {
+ onResult.result(attempt, msg, null);
+ }
+
+ @Override
+ public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+ {
+ if (!backoff.mayRetry(attempt))
+ {
+ onResult.result(attempt, null, new MaxRetriesException(attempt, errorMessage.apply(attempt, ResponseFailureReason.MaxRetries, from, failure)));
+ return;
+ }
+ if (!shouldRetry.test(attempt, from, failure))
+ {
+ onResult.result(attempt, null, new FailedResponseException(from, failure, errorMessage.apply(attempt, ResponseFailureReason.Rejected, from, failure)));
+ return;
+ }
+ try
+ {
+ retryThreads.schedule(() -> sendWithRetries(messaging, backoff, retryThreads, verb, request, candidates, onResult, shouldRetry, errorMessage, attempt + 1),
+ backoff.computeWaitTime(attempt), backoff.unit());
+ }
+ catch (Throwable t)
+ {
+ onResult.result(attempt, null, new FailedScheduleException(errorMessage.apply(attempt, ResponseFailureReason.FailedSchedule, from, failure), t));
+ }
+ }
+ }
+ messaging.sendWithCallback(Message.outWithFlag(verb, request, CALL_BACK_ON_FAILURE), candidates.next(), new Request());
+ }
+
+ enum ResponseFailureReason { MaxRetries, Rejected, NoMoreCandidates, Interrupted, FailedSchedule }
+
+ interface RetryScheduler
+ {
+ void schedule(Runnable command, long delay, TimeUnit unit);
+ }
+
+ enum ImmediateRetryScheduler implements RetryScheduler
+ {
+ instance;
+
+ @Override
+ public void schedule(Runnable command, long delay, TimeUnit unit)
+ {
+ command.run();
+ }
+ }
+
+ class NoMoreCandidatesException extends IllegalStateException
+ {
+ public NoMoreCandidatesException(String s)
+ {
+ super(s);
+ }
+ }
+
+ class FailedResponseException extends IllegalStateException
+ {
+ public final InetAddressAndPort from;
+ public final RequestFailureReason failure;
+
+ public FailedResponseException(InetAddressAndPort from, RequestFailureReason failure, String message)
+ {
+ super(message);
+ this.from = from;
+ this.failure = failure;
+ }
+ }
+
+ class MaxRetriesException extends IllegalStateException
+ {
+ public final int attempts;
+ public MaxRetriesException(int attempts, String message)
+ {
+ super(message);
+ this.attempts = attempts;
+ }
+ }
+
+ class FailedScheduleException extends IllegalStateException
+ {
+ public FailedScheduleException(String message, Throwable cause)
+ {
+ super(message, cause);
+ }
+ }
}
diff --git a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
index f0cbf78..835f90f 100644
--- a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
+++ b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
@@ -23,11 +23,13 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterators;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,8 +52,6 @@
import org.apache.cassandra.utils.TimeUUID;
import org.apache.cassandra.utils.concurrent.Future;
-import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
-
/**
* Base class of all repair related request/response messages.
*
@@ -138,9 +138,7 @@
{
RepairRetrySpec retrySpec = DatabaseDescriptor.getRepairRetrySpec();
RetrySpec spec = verb == Verb.VALIDATION_RSP ? retrySpec.getMerkleTreeResponseSpec() : retrySpec;
- if (!spec.isEnabled())
- return Backoff.None.INSTANCE;
- return new Backoff.ExponentialBackoff(spec.maxAttempts.value, spec.baseSleepTime.toMilliseconds(), spec.maxSleepTime.toMilliseconds(), ctx.random().get()::nextDouble);
+ return Backoff.fromConfig(ctx, spec);
}
public static Supplier<Boolean> notDone(Future<?> f)
@@ -155,98 +153,94 @@
public static <T> void sendMessageWithRetries(SharedContext ctx, Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, verb, endpoint, finalCallback, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, verb, endpoint, finalCallback);
}
public static <T> void sendMessageWithRetries(SharedContext ctx, RepairMessage request, Verb verb, InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, verb, endpoint, finalCallback, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, verb, endpoint, finalCallback);
}
public static void sendMessageWithRetries(SharedContext ctx, RepairMessage request, Verb verb, InetAddressAndPort endpoint)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, verb, endpoint, NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, verb, endpoint, NOOP_CALLBACK);
}
public static void sendMessageWithRetries(SharedContext ctx, Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, InetAddressAndPort endpoint)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, verb, endpoint, NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, verb, endpoint, NOOP_CALLBACK);
}
@VisibleForTesting
- static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff, Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, InetAddressAndPort endpoint, RequestCallback<T> finalCallback, int attempt)
+ static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff, Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
{
if (!ALLOWS_RETRY.contains(verb))
throw new AssertionError("Repair verb " + verb + " does not support retry, but a request to send with retry was given!");
- RequestCallback<T> callback = new RequestCallback<>()
- {
- @Override
- public void onResponse(Message<T> msg)
+ BiConsumer<Integer, RequestFailureReason > maybeRecordRetry = (attempt, reason) -> {
+ if (attempt <= 0)
+ return;
+ // we don't know what the prefix kind is... so use NONE... this impacts logPrefix as it will cause us to use "repair" rather than "preview repair" which may not be correct... but close enough...
+ String prefix = PreviewKind.NONE.logPrefix(request.parentRepairSession());
+ RepairMetrics.retry(verb, attempt);
+ if (reason == null)
{
- maybeRecordRetry(null);
- finalCallback.onResponse(msg);
+ noSpam.info("{} Retry of repair verb " + verb + " was successful after {} attempts", prefix, attempt);
}
-
- @Override
- public void onFailure(InetAddressAndPort from, RequestFailureReason failureReason)
+ else if (reason == RequestFailureReason.TIMEOUT)
{
- ErrorHandling allowed = errorHandlingSupported(ctx, endpoint, verb, request.parentRepairSession());
- switch (allowed)
- {
- case NONE:
- logger.error("[#{}] {} failed on {}: {}", request.parentRepairSession(), verb, from, failureReason);
- return;
- case TIMEOUT:
- finalCallback.onFailure(from, failureReason);
- return;
- case RETRY:
- int maxAttempts = backoff.maxAttempts();
- if (failureReason == RequestFailureReason.TIMEOUT && attempt < maxAttempts && allowRetry.get())
- {
- ctx.optionalTasks().schedule(() -> sendMessageWithRetries(ctx, backoff, allowRetry, request, verb, endpoint, finalCallback, attempt + 1),
- backoff.computeWaitTime(attempt), backoff.unit());
- return;
- }
- maybeRecordRetry(failureReason);
- finalCallback.onFailure(from, failureReason);
- return;
- default:
- throw new AssertionError("Unknown error handler: " + allowed);
- }
+ noSpam.warn("{} Timeout for repair verb " + verb + "; could not complete within {} attempts", prefix, attempt);
+ RepairMetrics.retryTimeout(verb);
}
-
- private void maybeRecordRetry(@Nullable RequestFailureReason reason)
+ else
{
- if (attempt <= 0)
- return;
- // we don't know what the prefix kind is... so use NONE... this impacts logPrefix as it will cause us to use "repair" rather than "preview repair" which may not be correct... but close enough...
- String prefix = PreviewKind.NONE.logPrefix(request.parentRepairSession());
- RepairMetrics.retry(verb, attempt);
- if (reason == null)
- {
- noSpam.info("{} Retry of repair verb " + verb + " was successful after {} attempts", prefix, attempt);
- }
- else if (reason == RequestFailureReason.TIMEOUT)
- {
- noSpam.warn("{} Timeout for repair verb " + verb + "; could not complete within {} attempts", prefix, attempt);
- RepairMetrics.retryTimeout(verb);
- }
- else
- {
- noSpam.warn("{} {} failure for repair verb " + verb + "; could not complete within {} attempts", prefix, reason, attempt);
- RepairMetrics.retryFailure(verb);
- }
- }
-
- @Override
- public boolean invokeOnFailure()
- {
- return true;
+ noSpam.warn("{} {} failure for repair verb " + verb + "; could not complete within {} attempts", prefix, reason, attempt);
+ RepairMetrics.retryFailure(verb);
}
};
- ctx.messaging().sendWithCallback(Message.outWithFlag(verb, request, CALL_BACK_ON_FAILURE),
- endpoint,
- callback);
+ ctx.messaging().sendWithRetries(backoff, ctx.optionalTasks()::schedule,
+ verb, request, Iterators.cycle(endpoint),
+ (int attempt, Message<T> msg, Throwable failure) -> {
+ if (failure == null)
+ {
+ maybeRecordRetry.accept(attempt, null);
+ finalCallback.onResponse(msg);
+ }
+ },
+ (attempt, from, failure) -> {
+ ErrorHandling allowed = errorHandlingSupported(ctx, endpoint, verb, request.parentRepairSession());
+ switch (allowed)
+ {
+ case NONE:
+ logger.error("[#{}] {} failed on {}: {}", request.parentRepairSession(), verb, from, failure);
+ return false;
+ case TIMEOUT:
+ finalCallback.onFailure(from, failure);
+ return false;
+ case RETRY:
+ if (failure == RequestFailureReason.TIMEOUT && allowRetry.get())
+ return true;
+ maybeRecordRetry.accept(attempt, failure);
+ finalCallback.onFailure(from, failure);
+ return false;
+ default:
+ throw new AssertionError("Unknown error handler: " + allowed);
+ }
+ },
+ (attempt, retryReason, from, failure) -> {
+ switch (retryReason)
+ {
+ case MaxRetries:
+ maybeRecordRetry.accept(attempt, failure);
+ finalCallback.onFailure(from, failure);
+ return null;
+ case Interrupted:
+ case Rejected:
+ case FailedSchedule:
+ return null;
+ default:
+ throw new UnsupportedOperationException(retryReason.name());
+ }
+ });
}
public static void sendMessageWithFailureCB(SharedContext ctx, Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, InetAddressAndPort endpoint, RepairFailureCallback failureCallback)
diff --git a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
index 79f0b7c..0ea055b 100644
--- a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
+++ b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
@@ -39,6 +39,7 @@
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.metrics.TCMMetrics;
import org.apache.cassandra.net.Message;
+import org.apache.cassandra.net.MessageDelivery;
import org.apache.cassandra.net.MessagingService;
import org.apache.cassandra.net.RequestCallbackWithFailure;
import org.apache.cassandra.net.Verb;
@@ -47,6 +48,7 @@
import org.apache.cassandra.tcm.log.LocalLog;
import org.apache.cassandra.tcm.log.LogState;
import org.apache.cassandra.utils.AbstractIterator;
+import org.apache.cassandra.utils.Backoff;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.concurrent.AsyncPromise;
import org.apache.cassandra.utils.concurrent.Future;
@@ -201,51 +203,45 @@
public static <REQ, RSP> void sendWithCallbackAsync(Promise<RSP> promise, Verb verb, REQ request, CandidateIterator candidates, Retry retryPolicy)
{
- class Request implements RequestCallbackWithFailure<RSP>
- {
- void retry()
- {
- if (promise.isCancelled() || promise.isDone())
- return;
- if (Thread.currentThread().isInterrupted())
- promise.setFailure(new InterruptedException());
- if (!candidates.hasNext())
- promise.tryFailure(new IllegalStateException(String.format("Ran out of candidates while sending %s: %s", verb, candidates)));
-
- MessagingService.instance().sendWithCallback(Message.out(verb, request), candidates.next(), this);
- }
-
- @Override
- public void onResponse(Message<RSP> msg)
- {
- promise.trySuccess(msg.payload);
- }
-
- @Override
- public void onFailure(InetAddressAndPort from, RequestFailureReason reason)
- {
- if (reason == RequestFailureReason.NOT_CMS)
- {
- logger.debug("{} is not a member of the CMS, querying it to discover current membership", from);
- DiscoveredNodes cms = tryDiscover(from);
- candidates.addCandidates(cms);
- candidates.timeout(from);
- logger.debug("Got CMS from {}: {}, retrying on: {}", from, cms, candidates);
- }
- else
- {
- candidates.timeout(from);
- logger.warn("Got error from {}: {} when sending {}, retrying on {}", from, reason, verb, candidates);
- }
-
- if (retryPolicy.reachedMax())
- promise.tryFailure(new IllegalStateException(String.format("Could not succeed sending %s to %s after %d tries", verb, candidates, retryPolicy.tries)));
- else
- retry();
- }
- }
-
- new Request().retry();
+ //TODO (now): the retry defines how long to wait for a retry, but the old behavior scheduled the message right away... should this be delayed as well?
+ MessagingService.instance().<REQ, RSP>sendWithRetries(Backoff.fromRetry(retryPolicy), MessageDelivery.ImmediateRetryScheduler.instance,
+ verb, request, candidates,
+ (attempt, success, failure) -> {
+ if (failure != null) promise.tryFailure(failure);
+ else promise.trySuccess(success.payload);
+ },
+ (attempt, from, failure) -> {
+ if (promise.isDone() || promise.isCancelled())
+ return false;
+ if (failure == RequestFailureReason.NOT_CMS)
+ {
+ logger.debug("{} is not a member of the CMS, querying it to discover current membership", from);
+ DiscoveredNodes cms = tryDiscover(from);
+ candidates.addCandidates(cms);
+ candidates.timeout(from);
+ logger.debug("Got CMS from {}: {}, retrying on: {}", from, cms, candidates);
+ }
+ else
+ {
+ candidates.timeout(from);
+ logger.warn("Got error from {}: {} when sending {}, retrying on {}", from, failure, verb, candidates);
+ }
+ return true;
+ },
+ (attempt, reason, from, failure) -> {
+ switch (reason)
+ {
+ case NoMoreCandidates:
+ return String.format("Ran out of candidates while sending %s: %s", verb, candidates);
+ case MaxRetries:
+ return String.format("Could not succeed sending %s to %s after %d tries", verb, candidates, retryPolicy.tries);
+ case Interrupted:
+ case FailedSchedule:
+ return null;
+ default:
+ throw new UnsupportedOperationException(reason.name());
+ }
+ });
}
private static DiscoveredNodes tryDiscover(InetAddressAndPort ep)
diff --git a/src/java/org/apache/cassandra/tcm/Retry.java b/src/java/org/apache/cassandra/tcm/Retry.java
index a1215fd..703e590 100644
--- a/src/java/org/apache/cassandra/tcm/Retry.java
+++ b/src/java/org/apache/cassandra/tcm/Retry.java
@@ -59,9 +59,14 @@
public void maybeSleep()
{
+ sleepUninterruptibly(computeSleepFor(), TimeUnit.MILLISECONDS);
+ }
+
+ public long computeSleepFor()
+ {
tries++;
retryMeter.mark();
- sleepUninterruptibly(sleepFor(), TimeUnit.MILLISECONDS);
+ return sleepFor();
}
protected abstract long sleepFor();
diff --git a/src/java/org/apache/cassandra/utils/Backoff.java b/src/java/org/apache/cassandra/utils/Backoff.java
index 2f0b7e2..7974dbf 100644
--- a/src/java/org/apache/cassandra/utils/Backoff.java
+++ b/src/java/org/apache/cassandra/utils/Backoff.java
@@ -21,23 +21,55 @@
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleSupplier;
+import org.apache.cassandra.config.RetrySpec;
+import org.apache.cassandra.repair.SharedContext;
+import org.apache.cassandra.tcm.Retry;
+
public interface Backoff
{
- /**
- * @return max attempts allowed, {@code == 0} implies no retries are allowed
- */
- int maxAttempts();
- long computeWaitTime(int retryCount);
+ boolean mayRetry(int attempt);
+ long computeWaitTime(int attempt);
TimeUnit unit();
+ static Backoff fromRetry(Retry retry)
+ {
+ return new Backoff()
+ {
+ @Override
+ public boolean mayRetry(int attempt)
+ {
+ return !retry.reachedMax();
+ }
+
+ @Override
+ public long computeWaitTime(int retryCount)
+ {
+ return retry.computeSleepFor();
+ }
+
+ @Override
+ public TimeUnit unit()
+ {
+ return TimeUnit.MILLISECONDS;
+ }
+ };
+ }
+
+ static Backoff fromConfig(SharedContext ctx, RetrySpec spec)
+ {
+ if (!spec.isEnabled())
+ return Backoff.None.INSTANCE;
+ return new Backoff.ExponentialBackoff(spec.maxAttempts.value, spec.baseSleepTime.toMilliseconds(), spec.maxSleepTime.toMilliseconds(), ctx.random().get()::nextDouble);
+ }
+
enum None implements Backoff
{
INSTANCE;
@Override
- public int maxAttempts()
+ public boolean mayRetry(int attempt)
{
- return 0;
+ return false;
}
@Override
@@ -68,13 +100,18 @@
this.randomSource = randomSource;
}
- @Override
public int maxAttempts()
{
return maxAttempts;
}
@Override
+ public boolean mayRetry(int attempt)
+ {
+ return attempt < maxAttempts;
+ }
+
+ @Override
public long computeWaitTime(int retryCount)
{
long baseTimeMillis = baseSleepTimeMillis * (1L << retryCount);
diff --git a/src/java/org/apache/cassandra/utils/TriFunction.java b/src/java/org/apache/cassandra/utils/TriFunction.java
new file mode 100644
index 0000000..c280850
--- /dev/null
+++ b/src/java/org/apache/cassandra/utils/TriFunction.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.utils;
+
+@FunctionalInterface
+public interface TriFunction<A, B, C, D>
+{
+ D apply(A var1, B var2, C var3);
+}
diff --git a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
index ffb1a0f..e206af3 100644
--- a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
+++ b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
@@ -18,6 +18,7 @@
package org.apache.cassandra.concurrent;
+import java.sql.Timestamp;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
@@ -33,13 +34,25 @@
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+
import accord.utils.Gens;
import accord.utils.RandomSource;
import org.apache.cassandra.utils.Clock;
+import org.apache.cassandra.utils.Generators;
+import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.UncheckedInterruptedException;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
+import static org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.SHUTTING_DOWN_NOW;
+import static org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.TERMINATED;
+import static org.apache.cassandra.concurrent.Interruptible.State.INTERRUPTED;
+import static org.apache.cassandra.concurrent.Interruptible.State.NORMAL;
+import static org.apache.cassandra.concurrent.Interruptible.State.SHUTTING_DOWN;
+import static org.apache.cassandra.utils.Generators.toGen;
public class SimulatedExecutorFactory implements ExecutorFactory, Clock
{
@@ -79,15 +92,57 @@
private final RandomSource rs;
private final long startTimeNanos;
+ @Nullable
+ private final Consumer<Throwable> onError;
private final PriorityQueue<Item> queue = new PriorityQueue<>();
private long seq = 0;
private long nowNanos;
private int repeatedTasks = 0;
+ public SimulatedExecutorFactory(RandomSource rs, Consumer<Throwable> onError)
+ {
+ this(rs, toGen(Generators.TIMESTAMP_GEN.map(Timestamp::getTime)).mapToLong(TimeUnit.MILLISECONDS::toNanos).next(rs), onError);
+ }
+
+ public SimulatedExecutorFactory(RandomSource rs)
+ {
+ this(rs, null);
+ }
+
public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos)
{
+ this(rs, startTimeNanos, null);
+ }
+
+ public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos, Consumer<Throwable> onError)
+ {
this.rs = rs;
this.startTimeNanos = startTimeNanos;
+ this.onError = onError;
+ }
+
+ private void maybeAddFailureListener(Future<?> task)
+ {
+ if (onError == null) return;
+ task.addCallback((s, f) -> {
+ if (f != null)
+ onError.accept(f);
+ });
+ }
+
+ public boolean hasWork()
+ {
+ return queue.size() > repeatedTasks;
+ }
+
+ public boolean processAny()
+ {
+ Item item = queue.poll();
+ if (item == null)
+ return false;
+ nowNanos = Math.max(nowNanos + 1, item.runAtNanos);
+ item.action.run();
+ return true;
}
public boolean processOne()
@@ -95,12 +150,14 @@
// if we count the repeated tasks, then processAll will never complete
if (queue.size() == repeatedTasks)
return false;
- Item item = queue.poll();
- if (item == null)
- return false;
- nowNanos = Math.max(nowNanos + 1, item.runAtNanos);
- item.action.run();
- return true;
+ return processAny();
+ }
+
+ public void processAll()
+ {
+ while (processOne())
+ {
+ }
}
@Override
@@ -153,9 +210,92 @@
}
@Override
- public Interruptible infiniteLoop(String name, Interruptible.Task task, InfiniteLoopExecutor.SimulatorSafe simulatorSafe, InfiniteLoopExecutor.Daemon daemon, InfiniteLoopExecutor.Interrupts interrupts)
+ public Interruptible infiniteLoop(String name,
+ Interruptible.Task task,
+ InfiniteLoopExecutor.SimulatorSafe simulatorSafe,
+ InfiniteLoopExecutor.Daemon daemon,
+ InfiniteLoopExecutor.Interrupts interrupts)
{
- throw new UnsupportedOperationException("TODO");
+ var delegate = new UnorderedScheduledExecutorService();
+ class Capture { UnorderedScheduledExecutorService.ScheduledFuture<?> f;}
+ Capture c = new Capture();
+ class I implements Interruptible
+ {
+ private Object state = NORMAL;
+ private boolean interrupted = false;
+ private void runOne()
+ {
+ Object cur = state;
+ if (cur == SHUTTING_DOWN_NOW || cur == SHUTTING_DOWN)
+ {
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ return;
+ }
+
+ if (cur == NORMAL && interrupted) cur = INTERRUPTED;
+ try
+ {
+ task.run((State) cur);
+ interrupted = false;
+ }
+ catch (TerminateException ignore)
+ {
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ }
+ catch (UncheckedInterruptedException | InterruptedException e)
+ {
+ interrupted = false;
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ }
+ catch (Throwable t)
+ {
+ if (onError != null)
+ onError.accept(t);
+ }
+ }
+
+ @Override
+ public void interrupt()
+ {
+ interrupted = true;
+ }
+
+ @Override
+ public boolean isTerminated()
+ {
+ return state == TERMINATED;
+ }
+
+ @Override
+ public void shutdown()
+ {
+ if (state != TERMINATED && state != SHUTTING_DOWN_NOW)
+ state = SHUTTING_DOWN;
+ }
+
+ @Override
+ public Object shutdownNow()
+ {
+ if (state != TERMINATED)
+ state = SHUTTING_DOWN_NOW;
+ return null;
+ }
+
+ @Override
+ public boolean awaitTermination(long timeout, TimeUnit units)
+ {
+ return isTerminated();
+ }
+ }
+ I i = new I();
+ c.f = delegate.scheduleAtFixedRate(i::runOne, 0, 0, NANOSECONDS);
+ return i;
}
@Override
@@ -329,7 +469,9 @@
public void execute(Runnable command)
{
checkNotShutdown();
- queue.add(new Item(nowWithJitter(), SimulatedExecutorFactory.this.seq++, taskFor(command)));
+ var action = taskFor(command);
+ maybeAddFailureListener(action);
+ queue.add(new Item(nowWithJitter(), SimulatedExecutorFactory.this.seq++, action));
}
protected void checkNotShutdown()
@@ -365,6 +507,7 @@
if (next == null)
return;
+ maybeAddFailureListener(next.action);
next.action.addCallback((s, f) -> afterExecution());
queue.add(next);
}
@@ -461,6 +604,8 @@
catch (Throwable t)
{
tryFailure(t);
+ if (onError != null)
+ onError.accept(t);
}
}
}
@@ -477,6 +622,7 @@
{
checkNotShutdown();
ScheduledFuture<V> task = new ScheduledFuture<>(seq++, delay, 0, NANOSECONDS, callable);
+ maybeAddFailureListener(task);
queue.add(new Item(nowWithJitter() + unit.toNanos(delay), task.sequenceNumber, task));
return task;
}
@@ -486,6 +632,7 @@
{
checkNotShutdown();
ScheduledFuture<?> task = new ScheduledFuture<>(seq++, initialDelay, period, unit, Executors.callable(command));
+ maybeAddFailureListener(task);
repeatedTasks++;
task.addCallback((s, f) -> repeatedTasks--);
queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay), task.sequenceNumber, task));
@@ -497,6 +644,7 @@
{
checkNotShutdown();
ScheduledFuture<?> task = new ScheduledFuture<>(seq++, initialDelay, -delay, unit, Executors.callable(command));
+ maybeAddFailureListener(task);
repeatedTasks++;
task.addCallback((s, f) -> repeatedTasks--);
queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay), task.sequenceNumber, task));
diff --git a/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
new file mode 100644
index 0000000..59d7106
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Iterators;
+import org.junit.Assert;
+import org.junit.Test;
+
+import accord.utils.RandomSource;
+import org.apache.cassandra.concurrent.ScheduledExecutorPlus;
+import org.apache.cassandra.concurrent.SimulatedExecutorFactory;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.MessageDelivery.FailedResponseException;
+import org.apache.cassandra.net.MessageDelivery.MaxRetriesException;
+import org.apache.cassandra.net.SimulatedMessageDelivery.Action;
+import org.apache.cassandra.net.SimulatedMessageDelivery.SimulatedMessageReceiver;
+import org.apache.cassandra.tcm.ClusterMetadataService;
+import org.apache.cassandra.tcm.StubClusterMetadataService;
+import org.apache.cassandra.utils.Backoff;
+import org.mockito.Mockito;
+
+import static accord.utils.Property.qt;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class MessageDeliveryTest
+{
+ private static final InetAddressAndPort ID1 = InetAddressAndPort.getByNameUnchecked("127.0.0.1");
+ private static final MessageDelivery.RetryErrorMessage RETRY_ERROR_MESSAGE = (i1, i2, i3, i4) -> null;
+ private static final MessageDelivery.RetryPredicate ALWAYS_RETRY = (i1, i2, i3) -> true;
+ private static final MessageDelivery.RetryPredicate ALWAYS_REJECT = (i1, i2, i3) -> false;
+
+ static
+ {
+ DatabaseDescriptor.clientInitialization();
+ DatabaseDescriptor.setPartitionerUnsafe(Murmur3Partitioner.instance);
+ ClusterMetadataService.setInstance(StubClusterMetadataService.forTesting());
+ }
+
+ @Test
+ public void sendWithRetryFailsAfterMaxAttempts()
+ {
+ qt().check(rs -> {
+ List<Throwable> failures = new ArrayList<>();
+ SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+ ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+ MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DROP);
+
+ int expectedRetries = 3;
+ Backoff backoff = new Backoff.ExponentialBackoff(expectedRetries, 200, 1000, rs.fork()::nextDouble);
+
+ Future<Message<Void>> result = messaging.sendWithRetries(backoff,
+ scheduler::schedule,
+ Verb.ECHO_REQ, NoPayload.noPayload,
+ Iterators.cycle(ID1),
+ ALWAYS_RETRY,
+ RETRY_ERROR_MESSAGE);
+ assertThat(result).isNotDone();
+ factory.processAll();
+ assertThat(result).isDone();
+
+ assertThat(getMaxRetriesException(result).attempts).isEqualTo(expectedRetries);
+ });
+ }
+
+ @Test
+ public void sendWithRetryFirstAttempt()
+ {
+ qt().check(rs -> {
+ List<Throwable> failures = new ArrayList<>();
+ SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+ ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+ MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DELIVER);
+
+ Backoff backoff = Mockito.mock(Backoff.class);
+
+ Future<Message<Void>> result = messaging.sendWithRetries(backoff,
+ scheduler::schedule,
+ Verb.ECHO_REQ, NoPayload.noPayload,
+ Iterators.cycle(ID1),
+ ALWAYS_RETRY,
+ RETRY_ERROR_MESSAGE);
+ assertThat(result).isNotDone();
+ factory.processAll();
+ assertThat(result).isDone();
+ assertThat(result.get().header.verb).isEqualTo(Verb.ECHO_RSP);
+ Mockito.verify(backoff, Mockito.never()).mayRetry(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.never()).computeWaitTime(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.never()).unit();
+ });
+ }
+
+ @Test
+ public void sendWithRetry()
+ {
+ qt().check(rs -> {
+ List<Throwable> failures = new ArrayList<>();
+ SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+ ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+
+ int maxAttempts = 3;
+ int expectedAttempts = 1;
+ AtomicInteger attempts = new AtomicInteger(0);
+ MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> attempts.incrementAndGet() >= (expectedAttempts + 1) ? Action.DELIVER : Action.DROP);
+
+ Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(maxAttempts, 200, 1000, rs.fork()::nextDouble));
+
+ Future<Message<Void>> result = messaging.sendWithRetries(backoff,
+ scheduler::schedule,
+ Verb.ECHO_REQ, NoPayload.noPayload,
+ Iterators.cycle(ID1),
+ ALWAYS_RETRY,
+ RETRY_ERROR_MESSAGE);
+ assertThat(result).isNotDone();
+ factory.processAll();
+ assertThat(result).isDone();
+ assertThat(result.get().header.verb).isEqualTo(Verb.ECHO_RSP);
+ Mockito.verify(backoff, Mockito.times(expectedAttempts)).mayRetry(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.times(expectedAttempts)).computeWaitTime(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.times(expectedAttempts)).unit();
+ });
+ }
+
+ @Test
+ public void sendWithRetryDontAllowRetry()
+ {
+ qt().check(rs -> {
+ List<Throwable> failures = new ArrayList<>();
+ SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+ ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+
+ MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DROP);
+
+ Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(3, 200, 1000, rs.fork()::nextDouble));
+
+ Future<Message<Void>> result = messaging.sendWithRetries(backoff,
+ scheduler::schedule,
+ Verb.ECHO_REQ, NoPayload.noPayload,
+ Iterators.cycle(ID1),
+ ALWAYS_REJECT,
+ RETRY_ERROR_MESSAGE);
+ assertThat(result).isNotDone();
+ factory.processAll();
+ assertThat(result).isDone();
+ FailedResponseException e = getFailedResponseException(result);
+ assertThat(e.from).isEqualTo(ID1);
+ assertThat(e.failure).isEqualTo(RequestFailureReason.TIMEOUT);
+ Mockito.verify(backoff, Mockito.times(1)).mayRetry(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.never()).computeWaitTime(Mockito.anyInt());
+ Mockito.verify(backoff, Mockito.never()).unit();
+ });
+ }
+
+ private static MessageDelivery simulatedMessages(RandomSource rs, ScheduledExecutorPlus scheduler, List<Throwable> failures, SimulatedMessageDelivery.ActionSupplier actionSupplier)
+ {
+ Map<InetAddressAndPort, SimulatedMessageReceiver> receivers = new HashMap<>();
+ SimulatedMessageDelivery messaging = new SimulatedMessageDelivery(ID1,
+ actionSupplier,
+ SimulatedMessageDelivery.randomDelay(rs),
+ (to, message) -> scheduler.execute(() -> receivers.get(to).recieve(message)),
+ (i1, i2, i3) -> {},
+ scheduler::schedule,
+ failures::add);
+ receivers.put(ID1, messaging.receiver(m -> messaging.respond(NoPayload.noPayload, m)));
+ return messaging;
+ }
+
+ private static FailedResponseException getFailedResponseException(Future<Message<Void>> result) throws InterruptedException
+ {
+ FailedResponseException ex;
+ try
+ {
+ result.get();
+ Assert.fail("Should have failed");
+ throw new AssertionError("Not Reachable");
+ }
+ catch (ExecutionException e)
+ {
+ ex = (FailedResponseException) e.getCause();
+ }
+ return ex;
+ }
+
+ private static MaxRetriesException getMaxRetriesException(Future<Message<Void>> result) throws InterruptedException
+ {
+ MaxRetriesException ex;
+ try
+ {
+ result.get();
+ Assert.fail("Should have failed");
+ throw new AssertionError("Not Reachable");
+ }
+ catch (ExecutionException e)
+ {
+ ex = (MaxRetriesException) e.getCause();
+ }
+ return ex;
+ }
+}
\ No newline at end of file
diff --git a/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
new file mode 100644
index 0000000..f36d04d
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+
+import accord.utils.Gens;
+import accord.utils.RandomSource;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
+import org.apache.cassandra.utils.concurrent.Future;
+
+public class SimulatedMessageDelivery implements MessageDelivery
+{
+ public enum Action { DELIVER, DELIVER_WITH_FAILURE, DROP, DROP_PARTITIONED, FAILURE }
+
+ public interface ActionSupplier
+ {
+ Action get(InetAddressAndPort self, Message<?> message, InetAddressAndPort to);
+ }
+
+ public interface NetworkDelaySupplier
+ {
+ @Nullable
+ Duration jitter(Message<?> message, InetAddressAndPort to);
+ }
+
+ public static NetworkDelaySupplier noDelay()
+ {
+ return (i1, i2) -> null;
+ }
+
+ public static NetworkDelaySupplier randomDelay(RandomSource rs)
+ {
+ class Connection
+ {
+ final InetAddressAndPort from, to;
+
+ private Connection(InetAddressAndPort from, InetAddressAndPort to)
+ {
+ this.from = from;
+ this.to = to;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Connection that = (Connection) o;
+ return from.equals(that.from) && to.equals(that.to);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(from, to);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "Connection{" + "from=" + from + ", to=" + to + '}';
+ }
+ }
+ final Map<Connection, LongSupplier> networkLatencies = new HashMap<>();
+ return (msg, to) -> {
+ InetAddressAndPort from = msg.from();
+ long delayNanos = networkLatencies.computeIfAbsent(new Connection(from, to), ignore -> {
+ long min = TimeUnit.MICROSECONDS.toNanos(500);
+ long maxSmall = TimeUnit.MILLISECONDS.toNanos(5);
+ long max = TimeUnit.SECONDS.toNanos(5);
+ LongSupplier small = () -> rs.nextLong(min, maxSmall);
+ LongSupplier large = () -> rs.nextLong(maxSmall, max);
+ return Gens.bools().runs(rs.nextInt(1, 11) / 100.0D, rs.nextInt(3, 15))
+ .mapToLong(b -> b ? large.getAsLong() : small.getAsLong())
+ .asLongSupplier(rs.fork());
+ }).getAsLong();
+ return Duration.ofNanos(delayNanos);
+ };
+ }
+
+ public interface Scheduler
+ {
+ void schedule(Runnable command, long delay, TimeUnit unit);
+ }
+
+ public interface DropListener
+ {
+ void onDrop(Action action, InetAddressAndPort from, Message<?> msg);
+ }
+
+ private final InetAddressAndPort self;
+ private final ActionSupplier actions;
+ private final NetworkDelaySupplier networkDelay;
+ private final BiConsumer<InetAddressAndPort, Message<?>> reciever;
+ private final DropListener onDropped;
+ private final Scheduler scheduler;
+ private final Consumer<Throwable> onError;
+ private final Map<CallbackKey, CallbackContext> callbacks = new HashMap<>();
+ private enum Status { Up, Down }
+ private Status status = Status.Up;
+
+ public SimulatedMessageDelivery(InetAddressAndPort self,
+ ActionSupplier actions,
+ NetworkDelaySupplier networkDelay,
+ BiConsumer<InetAddressAndPort, Message<?>> reciever,
+ DropListener onDropped,
+ Scheduler scheduler,
+ Consumer<Throwable> onError)
+ {
+ this.self = self;
+ this.actions = actions;
+ this.networkDelay = networkDelay;
+ this.reciever = reciever;
+ this.onDropped = onDropped;
+ this.scheduler = scheduler;
+ this.onError = onError;
+ }
+
+ public void stop()
+ {
+ callbacks.clear();
+ status = Status.Down;
+ }
+
+ @Override
+ public <REQ> void send(Message<REQ> message, InetAddressAndPort to)
+ {
+ message = message.withFrom(self);
+ maybeEnqueue(message, to, null);
+ }
+
+ @Override
+ public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb)
+ {
+ message = message.withFrom(self);
+ maybeEnqueue(message, to, cb);
+ }
+
+ @Override
+ public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType specifyConnection)
+ {
+ message = message.withFrom(self);
+ maybeEnqueue(message, to, cb);
+ }
+
+ @Override
+ public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ> message, InetAddressAndPort to)
+ {
+ AsyncPromise<Message<RSP>> promise = new AsyncPromise<>();
+ sendWithCallback(message, to, new RequestCallback<RSP>()
+ {
+ @Override
+ public void onResponse(Message<RSP> msg)
+ {
+ promise.trySuccess(msg);
+ }
+
+ @Override
+ public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+ {
+ promise.tryFailure(new MessagingService.FailureResponseException(from, failure));
+ }
+
+ @Override
+ public boolean invokeOnFailure()
+ {
+ return true;
+ }
+ });
+ return promise;
+ }
+
+ @Override
+ public <V> void respond(V response, Message<?> message)
+ {
+ send(message.responseWith(response), message.respondTo());
+ }
+
+ private <REQ, RSP> void maybeEnqueue(Message<REQ> message, InetAddressAndPort to, @Nullable RequestCallback<RSP> callback)
+ {
+ if (status != Status.Up)
+ return;
+ CallbackContext cb;
+ if (callback != null)
+ {
+ CallbackKey key = new CallbackKey(message.id(), to);
+ if (callbacks.containsKey(key))
+ throw new AssertionError("Message id " + message.id() + " to " + to + " already has a callback");
+ cb = new CallbackContext(callback);
+ callbacks.put(key, cb);
+ }
+ else
+ {
+ cb = null;
+ }
+ Action action = actions.get(self, message, to);
+ switch (action)
+ {
+ case DELIVER:
+ deliver(message, to);
+ break;
+ case DROP:
+ case DROP_PARTITIONED:
+ onDropped.onDrop(action, to, message);
+ break;
+ case DELIVER_WITH_FAILURE:
+ deliver(message, to);
+ case FAILURE:
+ if (action == Action.FAILURE)
+ onDropped.onDrop(action, to, message);
+ if (callback != null)
+ scheduler.schedule(() -> callback.onFailure(to, RequestFailureReason.UNKNOWN),
+ message.verb().expiresAfterNanos(), TimeUnit.NANOSECONDS);
+ return;
+ default:
+ throw new UnsupportedOperationException("Unknown action type: " + action);
+ }
+ if (cb != null)
+ {
+ scheduler.schedule(() -> {
+ CallbackContext ctx = callbacks.remove(new CallbackKey(message.id(), to));
+ if (ctx != null)
+ {
+ assert ctx == cb;
+ try
+ {
+ ctx.onFailure(to, RequestFailureReason.TIMEOUT);
+ }
+ catch (Throwable t)
+ {
+ onError.accept(t);
+ }
+ }
+ }, message.verb().expiresAfterNanos(), TimeUnit.NANOSECONDS);
+ }
+ }
+
+ private void deliver(Message<?> message, InetAddressAndPort to)
+ {
+ Duration delay = networkDelay.jitter(message, to);
+ if (delay == null) reciever.accept(to, message);
+ else scheduler.schedule(() -> reciever.accept(to, message), delay.toNanos(), TimeUnit.NANOSECONDS);
+ }
+
+ @SuppressWarnings("rawtypes")
+ public SimulatedMessageReceiver receiver(IVerbHandler onMessage)
+ {
+ return new SimulatedMessageReceiver(onMessage);
+ }
+
+ public class SimulatedMessageReceiver
+ {
+ @SuppressWarnings("rawtypes")
+ final IVerbHandler onMessage;
+
+ @SuppressWarnings("rawtypes")
+ public SimulatedMessageReceiver(IVerbHandler onMessage)
+ {
+ this.onMessage = onMessage;
+ }
+
+ public void recieve(Message<?> msg)
+ {
+ if (status != Status.Up)
+ return;
+ if (msg.verb().isResponse())
+ {
+ CallbackKey key = new CallbackKey(msg.id(), msg.from());
+ if (callbacks.containsKey(key))
+ {
+ CallbackContext callback = callbacks.remove(key);
+ if (callback == null)
+ return;
+ try
+ {
+ if (msg.isFailureResponse())
+ callback.onFailure(msg.from(), (RequestFailureReason) msg.payload);
+ else callback.onResponse(msg);
+ }
+ catch (Throwable t)
+ {
+ onError.accept(t);
+ }
+ }
+ }
+ else
+ {
+ try
+ {
+ //noinspection unchecked
+ onMessage.doVerb(msg);
+ }
+ catch (Throwable t)
+ {
+ onError.accept(t);
+ }
+ }
+ }
+ }
+
+ @SuppressWarnings("rawtypes")
+ public static class SimpleVerbHandler implements IVerbHandler
+ {
+ private final Map<Verb, IVerbHandler<?>> handlers;
+
+ public SimpleVerbHandler(Map<Verb, IVerbHandler<?>> handlers)
+ {
+ this.handlers = handlers;
+ }
+
+ @Override
+ public void doVerb(Message msg) throws IOException
+ {
+ IVerbHandler<?> handler = handlers.get(msg.verb());
+ if (handler == null)
+ throw new AssertionError("Unexpected verb: " + msg.verb());
+ //noinspection unchecked
+ handler.doVerb(msg);
+ }
+ }
+
+ private static class CallbackContext
+ {
+ @SuppressWarnings("rawtypes")
+ final RequestCallback callback;
+
+ @SuppressWarnings("rawtypes")
+ private CallbackContext(RequestCallback callback)
+ {
+ this.callback = Objects.requireNonNull(callback);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ public void onResponse(Message msg)
+ {
+ callback.onResponse(msg);
+ }
+
+ public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+ {
+ if (callback.invokeOnFailure()) callback.onFailure(from, failure);
+ }
+ }
+
+ private static class CallbackKey
+ {
+ private final long id;
+ private final InetAddressAndPort peer;
+
+ private CallbackKey(long id, InetAddressAndPort peer)
+ {
+ this.id = id;
+ this.peer = peer;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ CallbackKey that = (CallbackKey) o;
+ return id == that.id && peer.equals(that.peer);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(id, peer);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "CallbackKey{" +
+ "id=" + id +
+ ", peer=" + peer +
+ '}';
+ }
+ }
+}
diff --git a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
index b01a9fc..fb3ce47 100644
--- a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
+++ b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
@@ -155,7 +155,8 @@
{
SharedContext ctx = Mockito.mock(SharedContext.class, REJECT_ALL);
MessageDelivery messaging = Mockito.mock(MessageDelivery.class, REJECT_ALL);
- // allow the single method under test
+ // allow all retry methods and send with callback
+ Mockito.doCallRealMethod().when(messaging).sendWithRetries(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
Mockito.doNothing().when(messaging).sendWithCallback(Mockito.any(), Mockito.any(), Mockito.any());
IGossiper gossiper = Mockito.mock(IGossiper.class, REJECT_ALL);
Mockito.doReturn(RepairMessage.SUPPORTS_RETRY).when(gossiper).getReleaseVersion(Mockito.any());
@@ -205,7 +206,7 @@
{
before();
- sendMessageWithRetries(ctx, backoff(maxAttempts), always(), PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(maxAttempts), always(), PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK);
for (int i = 0; i < maxAttempts; i++)
callback(messaging).onFailure(ADDRESS, RequestFailureReason.TIMEOUT);
fn.test(maxAttempts, callback(messaging));