CEP-15 (Accord) Original and recover coordinators may hit a race condition with PreApply where reads and writes are interleaved, causing one of the coordinators to see the writes from the other
patch by David Capwell; reviewed by Ariel Weisberg for CASSANDRA-18422
diff --git a/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java b/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java
index e81186c..ff34236 100644
--- a/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java
+++ b/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java
@@ -809,6 +809,29 @@
}
@Override
+ public <T> AsyncChain<T> submit(Callable<T> task)
+ {
+ return new AsyncChains.Head<T>()
+ {
+ @Override
+ protected void start(BiConsumer<? super T, Throwable> callback)
+ {
+ enqueueAndRun(() -> {
+ try
+ {
+ callback.accept(task.call(), null);
+ }
+ catch (Throwable t)
+ {
+ logger.error("Uncaught exception", t);
+ callback.accept(null, t);
+ }
+ });
+ }
+ };
+ }
+
+ @Override
public void shutdown() {}
}
@@ -865,6 +888,12 @@
}
@Override
+ public <T> AsyncChain<T> submit(Callable<T> task)
+ {
+ return AsyncChains.ofCallable(executor, task);
+ }
+
+ @Override
public void shutdown()
{
executor.shutdown();
diff --git a/accord-core/src/main/java/accord/local/CommandStore.java b/accord-core/src/main/java/accord/local/CommandStore.java
index 65f8949..479f581 100644
--- a/accord-core/src/main/java/accord/local/CommandStore.java
+++ b/accord-core/src/main/java/accord/local/CommandStore.java
@@ -23,6 +23,7 @@
import accord.api.DataStore;
import accord.local.CommandStores.RangesForEpochHolder;
import accord.utils.async.AsyncChain;
+import accord.utils.async.AsyncExecutor;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -30,7 +31,7 @@
/**
* Single threaded internal shard of accord transaction metadata
*/
-public interface CommandStore
+public interface CommandStore extends AsyncExecutor
{
interface Factory
{
@@ -46,5 +47,12 @@
Agent agent();
AsyncChain<Void> execute(PreLoadContext context, Consumer<? super SafeCommandStore> consumer);
<T> AsyncChain<T> submit(PreLoadContext context, Function<? super SafeCommandStore, T> apply);
+
+ @Override
+ default void execute(Runnable command)
+ {
+ submit(command).begin(agent());
+ }
+
void shutdown();
}
diff --git a/accord-core/src/main/java/accord/local/Commands.java b/accord-core/src/main/java/accord/local/Commands.java
index e124462..66546a4 100644
--- a/accord-core/src/main/java/accord/local/Commands.java
+++ b/accord-core/src/main/java/accord/local/Commands.java
@@ -402,6 +402,7 @@
attrs = set(safeStore, command, attrs, coordinateRanges, executeRanges, shard, route, null, Check, partialDeps, command.hasBeen(Committed) ? Add : TrySet);
safeCommand.preapplied(attrs, executeAt, waitingOn, writes, result);
+ safeStore.notifyListeners(safeCommand);
logger.trace("{}: apply, status set to Executed with executeAt: {}, deps: {}", txnId, executeAt, partialDeps);
maybeExecute(safeStore, safeCommand, shard, true, true);
diff --git a/accord-core/src/main/java/accord/messages/ReadData.java b/accord-core/src/main/java/accord/messages/ReadData.java
index a4ebb4f..d7558d3 100644
--- a/accord-core/src/main/java/accord/messages/ReadData.java
+++ b/accord-core/src/main/java/accord/messages/ReadData.java
@@ -29,6 +29,7 @@
import accord.api.Data;
import accord.topology.Topologies;
import accord.utils.Invariants;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -53,12 +54,41 @@
return new ReadData(txnId, scope, executeAtEpoch, waitForEpoch);
}
}
+ private class ObsoleteTracker implements CommandListener
+ {
+ @Override
+ public void onChange(SafeCommandStore safeStore, SafeCommand safeCommand)
+ {
+ switch (safeCommand.current().status())
+ {
+ case PreApplied:
+ case Applied:
+ case Invalidated:
+ obsolete();
+ safeCommand.removeListener(this);
+ }
+ }
+ @Override
+ public PreLoadContext listenerPreLoadContext(TxnId caller)
+ {
+ return ReadData.this.listenerPreLoadContext(caller);
+ }
+
+ @Override
+ public boolean isTransient()
+ {
+ return true;
+ }
+ }
+
+ private final ObsoleteTracker obsoleteTracker = new ObsoleteTracker();
public final long executeAtEpoch;
public final Seekables<?, ?> readScope; // TODO (low priority, efficiency): this should be RoutingKeys, as we have the Keys locally, but for simplicity we use this to implement keys()
private final long waitForEpoch;
private Data data;
- private transient boolean isObsolete; // TODO (low priority, semantics): respond with the Executed result we have stored?
+ private enum State { PENDING, RETURNED, OBSOLETE }
+ private transient State state = State.PENDING; // TODO (low priority, semantics): respond with the Executed result we have stored?
private transient BitSet waitingOn;
private transient int waitingOnCount;
@@ -131,10 +161,8 @@
case ReadyToExecute:
}
- command = safeCommand.removeListener(this);
-
- if (!isObsolete)
- read(safeStore, command.asCommitted());
+ safeCommand.removeListener(this);
+ maybeRead(safeStore, safeCommand);
}
@Override
@@ -145,7 +173,7 @@
logger.trace("{}: setting up read with status {} on {}", txnId, status, safeStore);
switch (status) {
default:
- throw new AssertionError();
+ throw new AssertionError("Unknown status: " + status);
case Committed:
case NotWitnessed:
case PreAccepted:
@@ -166,18 +194,34 @@
case ReadyToExecute:
waitingOn.set(safeStore.commandStore().id());
++waitingOnCount;
- if (!isObsolete)
- read(safeStore, safeCommand.current().asCommitted());
+ maybeRead(safeStore, safeCommand);
return null;
case PreApplied:
case Applied:
case Invalidated:
- isObsolete = true;
+ state = State.OBSOLETE;
return Redundant;
}
}
+ private void maybeRead(SafeCommandStore safeStore, SafeCommand safeCommand)
+ {
+ switch (state)
+ {
+ case PENDING:
+ read(safeStore, safeCommand, safeCommand.current().asCommitted());
+ break;
+ case OBSOLETE:
+ // nothing to see here
+ break;
+ case RETURNED:
+ throw new IllegalStateException("ReadOk was sent, yet ack called again");
+ default:
+ throw new AssertionError("Unknown state: " + state);
+ }
+ }
+
@Override
public ReadNack reduce(ReadNack r1, ReadNack r2)
{
@@ -219,12 +263,27 @@
// and prevents races where we respond before dispatching all the required reads (if the reads are
// completing faster than the reads can be setup on all required shards)
if (-1 == --waitingOnCount)
- node.reply(replyTo, replyContext, new ReadOk(data));
+ {
+ switch (state)
+ {
+ case RETURNED:
+ throw new IllegalStateException("ReadOk was sent, yet ack called again");
+ case OBSOLETE:
+ logger.debug("After the read completed for txn {}, the result was marked obsolete", txnId);
+ break;
+ case PENDING:
+ state = State.RETURNED;
+ node.reply(replyTo, replyContext, new ReadOk(data));
+ break;
+ default:
+ throw new AssertionError("Unknown state: " + state);
+ }
+ }
}
private synchronized void readComplete(CommandStore commandStore, Data result)
{
- Invariants.checkState(waitingOn.get(commandStore.id()), "Waiting on does not contain store %d; waitingOn=%s", commandStore.id(), waitingOn);
+ Invariants.checkState(waitingOn.get(commandStore.id()), "Txn %s's waiting on does not contain store %d; waitingOn=%s", txnId, commandStore.id(), waitingOn);
logger.trace("{}: read completed on {}", txnId, commandStore);
if (result != null)
data = data == null ? result : data.merge(result);
@@ -233,8 +292,9 @@
ack();
}
- private void read(SafeCommandStore safeStore, Command.Committed command)
+ private void read(SafeCommandStore safeStore, SafeCommand safeCommand, Command.Committed command)
{
+ safeCommand.addListener(obsoleteTracker);
CommandStore unsafeStore = safeStore.commandStore();
logger.trace("{}: executing read", command.txnId());
command.read(safeStore).begin((next, throwable) -> {
@@ -249,11 +309,11 @@
});
}
- void obsolete()
+ synchronized void obsolete()
{
- if (!isObsolete)
+ if (state == State.PENDING)
{
- isObsolete = true;
+ state = State.OBSOLETE;
node.reply(replyTo, replyContext, Redundant);
}
}
diff --git a/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java b/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java
new file mode 100644
index 0000000..50d2116
--- /dev/null
+++ b/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accord.utils.async;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+public interface AsyncExecutor extends Executor
+{
+ default AsyncChain<?> submit(Runnable task)
+ {
+ return submit(Executors.callable(task));
+ }
+
+ default <T> AsyncChain<T> submit(Runnable task, T result)
+ {
+ return submit(Executors.callable(task, result));
+ }
+
+ <T> AsyncChain<T> submit(Callable<T> task);
+}
diff --git a/accord-core/src/main/java/accord/utils/async/AsyncResult.java b/accord-core/src/main/java/accord/utils/async/AsyncResult.java
index 59f97b7..3269f7a 100644
--- a/accord-core/src/main/java/accord/utils/async/AsyncResult.java
+++ b/accord-core/src/main/java/accord/utils/async/AsyncResult.java
@@ -79,7 +79,11 @@
default void setFailure(Throwable throwable)
{
if (!tryFailure(throwable))
- throw new IllegalStateException("Result has already been set on " + this);
+ {
+ IllegalStateException e = new IllegalStateException("Result has already been set on " + this);
+ e.addSuppressed(throwable);
+ throw e;
+ }
}
default BiConsumer<V, Throwable> settingCallback()
diff --git a/accord-core/src/test/java/accord/Utils.java b/accord-core/src/test/java/accord/Utils.java
index 42dab7d..ce5a80c 100644
--- a/accord-core/src/test/java/accord/Utils.java
+++ b/accord-core/src/test/java/accord/Utils.java
@@ -18,7 +18,16 @@
package accord;
+import accord.api.MessageSink;
+import accord.api.Scheduler;
+import accord.impl.InMemoryCommandStores;
+import accord.impl.IntKey;
+import accord.impl.SimpleProgressLog;
import accord.impl.SizeOfIntersectionSorter;
+import accord.impl.TestAgent;
+import accord.impl.mock.MockCluster;
+import accord.impl.mock.MockConfigurationService;
+import accord.local.ShardDistributor;
import accord.primitives.Range;
import accord.local.Node;
import accord.impl.mock.MockStore;
@@ -28,7 +37,11 @@
import accord.topology.Topology;
import accord.primitives.Txn;
import accord.primitives.Keys;
+import accord.utils.DefaultRandom;
+import accord.utils.EpochFunction;
import accord.utils.Invariants;
+import accord.utils.ThreadPoolScheduler;
+
import com.google.common.collect.Sets;
import java.util.ArrayList;
@@ -117,4 +130,22 @@
{
return new Topologies.Multi(SizeOfIntersectionSorter.SUPPLIER, topologies);
}
+
+ public static Node createNode(Node.Id nodeId, Topology topology, MessageSink messageSink, MockCluster.Clock clock)
+ {
+ MockStore store = new MockStore();
+ Scheduler scheduler = new ThreadPoolScheduler();
+ return new Node(nodeId,
+ messageSink,
+ new MockConfigurationService(messageSink, EpochFunction.noop(), topology),
+ clock,
+ () -> store,
+ new ShardDistributor.EvenSplit(8, ignore -> new IntKey.Splitter()),
+ new TestAgent(),
+ new DefaultRandom(),
+ scheduler,
+ SizeOfIntersectionSorter.SUPPLIER,
+ SimpleProgressLog::new,
+ InMemoryCommandStores.Synchronized::new);
+ }
}
diff --git a/accord-core/src/test/java/accord/burn/BurnTest.java b/accord-core/src/test/java/accord/burn/BurnTest.java
index 18a923d..78269a7 100644
--- a/accord-core/src/test/java/accord/burn/BurnTest.java
+++ b/accord-core/src/test/java/accord/burn/BurnTest.java
@@ -32,32 +32,39 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
+import java.util.function.Function;
import java.util.function.LongSupplier;
import java.util.function.Predicate;
-import accord.utils.DefaultRandom;
-import accord.utils.RandomSource;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import accord.api.Key;
import accord.impl.IntHashKey;
-import accord.impl.basic.Cluster;
-import accord.impl.basic.PropagatingPendingQueue;
-import accord.impl.basic.RandomDelayQueue.Factory;
import accord.impl.TopologyFactory;
+import accord.impl.basic.Cluster;
import accord.impl.basic.Packet;
import accord.impl.basic.PendingQueue;
+import accord.impl.basic.PropagatingPendingQueue;
+import accord.impl.basic.RandomDelayQueue.Factory;
+import accord.impl.basic.SimulatedDelayedExecutorService;
import accord.impl.list.ListQuery;
import accord.impl.list.ListRead;
import accord.impl.list.ListRequest;
import accord.impl.list.ListResult;
import accord.impl.list.ListUpdate;
+import accord.local.CommandStore;
import accord.local.Node.Id;
-import accord.api.Key;
-import accord.primitives.*;
+import accord.primitives.Keys;
+import accord.primitives.Range;
+import accord.primitives.Ranges;
+import accord.primitives.Txn;
+import accord.utils.DefaultRandom;
+import accord.utils.RandomSource;
+import accord.utils.async.AsyncExecutor;
import accord.verify.StrictSerializabilityVerifier;
-import org.junit.jupiter.api.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import static accord.impl.IntHashKey.forHash;
import static accord.utils.Utils.toArray;
@@ -65,7 +72,7 @@
{
private static final Logger logger = LoggerFactory.getLogger(BurnTest.class);
- static List<Packet> generate(RandomSource random, List<Id> clients, List<Id> nodes, int keyCount, int operations)
+ static List<Packet> generate(RandomSource random, Function<? super CommandStore, AsyncExecutor> executor, List<Id> clients, List<Id> nodes, int keyCount, int operations)
{
List<Key> keys = new ArrayList<>();
for (int i = 0 ; i < keyCount ; ++i)
@@ -73,6 +80,7 @@
List<Packet> packets = new ArrayList<>();
int[] next = new int[keyCount];
+ double readInCommandStore = random.nextDouble();
for (int count = 0 ; count < operations ; ++count)
{
@@ -90,7 +98,7 @@
requestRanges.add(IntHashKey.range(forHash(i), forHash(j)));
}
Ranges ranges = Ranges.of(requestRanges.toArray(new Range[0]));
- ListRead read = new ListRead(ranges, ranges);
+ ListRead read = new ListRead(random.decide(readInCommandStore) ? Function.identity() : executor, ranges, ranges);
ListQuery query = new ListQuery(client, count);
ListRequest request = new ListRequest(new Txn.InMemory(ranges, read, query, null));
packets.add(new Packet(client, node, count, request));
@@ -107,7 +115,7 @@
while (readCount-- > 0)
requestKeys.add(randomKey(random, keys, requestKeys));
- ListUpdate update = isWrite ? new ListUpdate() : null;
+ ListUpdate update = isWrite ? new ListUpdate(executor) : null;
while (writeCount-- > 0)
{
int i = randomKeyIndex(random, keys, update.keySet());
@@ -117,7 +125,7 @@
Keys readKeys = new Keys(requestKeys);
if (isWrite)
requestKeys.addAll(update.keySet());
- ListRead read = new ListRead(readKeys, new Keys(requestKeys));
+ ListRead read = new ListRead(random.decide(readInCommandStore) ? Function.identity() : executor, readKeys, new Keys(requestKeys));
ListQuery query = new ListQuery(client, count);
ListRequest request = new ListRequest(new Txn.InMemory(new Keys(requestKeys), read, query, update));
packets.add(new Packet(client, node, count, request));
@@ -191,10 +199,12 @@
{
List<Throwable> failures = Collections.synchronizedList(new ArrayList<>());
PendingQueue queue = new PropagatingPendingQueue(failures, new Factory(random).get());
+ SimulatedDelayedExecutorService globalExecutor = new SimulatedDelayedExecutorService(queue, random.fork());
StrictSerializabilityVerifier strictSerializable = new StrictSerializabilityVerifier(keyCount);
+ Function<CommandStore, AsyncExecutor> executor = ignore -> globalExecutor;
- Packet[] requests = toArray(generate(random, clients, nodes, keyCount, operations), Packet[]::new);
+ Packet[] requests = toArray(generate(random, executor, clients, nodes, keyCount, operations), Packet[]::new);
int[] starts = new int[requests.length];
Packet[] replies = new Packet[requests.length];
diff --git a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java b/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java
index e400d08..850f717 100644
--- a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java
+++ b/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java
@@ -18,20 +18,15 @@
package accord.impl.basic;
+import java.util.concurrent.Callable;
+import java.util.concurrent.TimeUnit;
+
import accord.burn.random.FrequentLargeRange;
import accord.burn.random.RandomLong;
import accord.burn.random.RandomWalkRange;
import accord.utils.RandomSource;
-import java.util.Collections;
-import java.util.List;
-import java.util.concurrent.AbstractExecutorService;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Executors;
-import java.util.concurrent.FutureTask;
-import java.util.concurrent.TimeUnit;
-
-public class SimulatedDelayedExecutorService extends AbstractExecutorService
+public class SimulatedDelayedExecutorService extends TaskExecutorService
{
private final PendingQueue pending;
private final RandomSource random;
@@ -44,10 +39,9 @@
// this is different from Apache Cassandra Simulator as this is computed differently for each executor
// rather than being a global config
double ratio = random.nextInt(1, 11) / 100.0D;
- this.jitterInNano = new FrequentLargeRange(
- new RandomWalkRange(random, microToNanos(0), microToNanos(50)),
- new RandomWalkRange(random, microToNanos(50), msToNanos(5)),
- ratio);
+ this.jitterInNano = new FrequentLargeRange(new RandomWalkRange(random, microToNanos(0), microToNanos(50)),
+ new RandomWalkRange(random, microToNanos(50), msToNanos(5)),
+ ratio);
}
private static int msToNanos(int value)
@@ -61,63 +55,15 @@
}
@Override
- protected <T> Task<T> newTaskFor(Runnable runnable, T value)
+ public void execute(Task<?> task)
{
- return newTaskFor(Executors.callable(runnable, value));
+ pending.add(task, jitterInNano.getLong(random), TimeUnit.NANOSECONDS);
}
- @Override
- protected <T> Task<T> newTaskFor(Callable<T> callable)
+ public <T> Task<T> submit(Callable<T> fn, long delay, TimeUnit unit)
{
- return new Task<>(callable);
+ Task<T> task = newTaskFor(fn);
+ pending.add(task, jitterInNano.getLong(random) + unit.toNanos(delay), TimeUnit.NANOSECONDS);
+ return task;
}
-
- private Task<?> newTaskFor(Runnable command)
- {
- return command instanceof Task ? (Task<?>) command : newTaskFor(command, null);
- }
-
- @Override
- public void execute(Runnable command)
- {
- pending.add(newTaskFor(command), jitterInNano.getLong(random), TimeUnit.NANOSECONDS);
- }
-
- @Override
- public void shutdown()
- {
- }
-
- @Override
- public List<Runnable> shutdownNow()
- {
- return Collections.emptyList();
- }
-
- @Override
- public boolean isShutdown()
- {
- return false;
- }
-
- @Override
- public boolean isTerminated()
- {
- return false;
- }
-
- @Override
- public boolean awaitTermination(long timeout, TimeUnit unit)
- {
- return false;
- }
-
-
- private static class Task<T> extends FutureTask<T> implements Pending
- {
- public Task(Callable<T> fn)
- {
- super(fn);
- }
- }
-}
+}
\ No newline at end of file
diff --git a/accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java b/accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java
new file mode 100644
index 0000000..2ab766b
--- /dev/null
+++ b/accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accord.impl.basic;
+
+import java.util.List;
+import java.util.concurrent.AbstractExecutorService;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.RunnableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import accord.utils.async.AsyncExecutor;
+import accord.utils.async.AsyncResults;
+
+public abstract class TaskExecutorService extends AbstractExecutorService implements AsyncExecutor
+{
+ public static class Task<T> extends AsyncResults.SettableResult<T> implements Pending, RunnableFuture<T>
+ {
+ private final Callable<T> fn;
+
+ public Task(Callable<T> fn)
+ {
+ this.fn = fn;
+ }
+
+ @Override
+ public void run()
+ {
+ try
+ {
+ setSuccess(fn.call());
+ }
+ catch (Throwable t)
+ {
+ setFailure(t);
+ }
+ }
+
+ @Override
+ public boolean cancel(boolean mayInterruptIfRunning)
+ {
+ return false;
+ }
+
+ @Override
+ public boolean isCancelled()
+ {
+ return false;
+ }
+
+ @Override
+ public T get() throws InterruptedException, ExecutionException
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException
+ {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ @Override
+ protected <T> Task<T> newTaskFor(Runnable runnable, T value)
+ {
+ return newTaskFor(Executors.callable(runnable, value));
+ }
+
+ @Override
+ protected <T> Task<T> newTaskFor(Callable<T> callable)
+ {
+ return new Task<>(callable);
+ }
+
+ private Task<?> newTaskFor(Runnable command)
+ {
+ return command instanceof Task ? (Task<?>) command : newTaskFor(command, null);
+ }
+
+ protected abstract void execute(Task<?> task);
+
+ @Override
+ public final void execute(Runnable command)
+ {
+ execute(newTaskFor(command));
+ }
+
+ @Override
+ public Task<?> submit(Runnable task)
+ {
+ return (Task<?>) super.submit(task);
+ }
+
+ @Override
+ public <T> Task<T> submit(Runnable task, T result)
+ {
+ return (Task<T>) super.submit(task, result);
+ }
+
+ @Override
+ public <T> Task<T> submit(Callable<T> task)
+ {
+ return (Task<T>) super.submit(task);
+ }
+
+ @Override
+ public void shutdown()
+ {
+
+ }
+
+ @Override
+ public List<Runnable> shutdownNow()
+ {
+ return null;
+ }
+
+ @Override
+ public boolean isShutdown()
+ {
+ return false;
+ }
+
+ @Override
+ public boolean isTerminated()
+ {
+ return false;
+ }
+
+ @Override
+ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException
+ {
+ return false;
+ }
+}
\ No newline at end of file
diff --git a/accord-core/src/test/java/accord/impl/list/ListRead.java b/accord-core/src/test/java/accord/impl/list/ListRead.java
index c8ea05c..67435b9 100644
--- a/accord-core/src/test/java/accord/impl/list/ListRead.java
+++ b/accord-core/src/test/java/accord/impl/list/ListRead.java
@@ -18,29 +18,38 @@
package accord.impl.list;
-import accord.api.*;
-import accord.local.SafeCommandStore;
-import accord.primitives.*;
-import accord.primitives.Ranges;
-import accord.primitives.Keys;
-import accord.primitives.Timestamp;
-import accord.primitives.Txn;
-import accord.utils.async.AsyncChain;
-import accord.utils.async.AsyncChains;
+import java.util.Map;
+import java.util.function.Function;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.Map;
+import accord.api.Data;
+import accord.api.DataStore;
+import accord.api.Key;
+import accord.api.Read;
+import accord.local.CommandStore;
+import accord.local.SafeCommandStore;
+import accord.primitives.Range;
+import accord.primitives.Ranges;
+import accord.primitives.Seekable;
+import accord.primitives.Seekables;
+import accord.primitives.Timestamp;
+import accord.primitives.Txn;
+import accord.utils.async.AsyncChain;
+import accord.utils.async.AsyncExecutor;
public class ListRead implements Read
{
private static final Logger logger = LoggerFactory.getLogger(ListRead.class);
+ private final Function<? super CommandStore, AsyncExecutor> executor;
public final Seekables<?, ?> readKeys;
public final Seekables<?, ?> keys;
- public ListRead(Seekables<?, ?> readKeys, Seekables<?, ?> keys)
+ public ListRead(Function<? super CommandStore, AsyncExecutor> executor, Seekables<?, ?> readKeys, Seekables<?, ?> keys)
{
+ this.executor = executor;
this.readKeys = readKeys;
this.keys = keys;
}
@@ -55,32 +64,34 @@
public AsyncChain<Data> read(Seekable key, Txn.Kind kind, SafeCommandStore commandStore, Timestamp executeAt, DataStore store)
{
ListStore s = (ListStore)store;
- ListData result = new ListData();
- switch (key.domain())
- {
- default: throw new AssertionError();
- case Key:
- int[] data = s.get((Key)key);
- logger.trace("READ on {} at {} key:{} -> {}", s.node, executeAt, key, data);
- result.put((Key)key, data);
- break;
- case Range:
- for (Map.Entry<Key, int[]> e : s.get((Range)key))
- result.put(e.getKey(), e.getValue());
- }
- return AsyncChains.success(result);
+ return executor.apply(commandStore.commandStore()).submit(() -> {
+ ListData result = new ListData();
+ switch (key.domain())
+ {
+ default: throw new AssertionError();
+ case Key:
+ int[] data = s.get((Key)key);
+ logger.trace("READ on {} at {} key:{} -> {}", s.node, executeAt, key, data);
+ result.put((Key)key, data);
+ break;
+ case Range:
+ for (Map.Entry<Key, int[]> e : s.get((Range)key))
+ result.put(e.getKey(), e.getValue());
+ }
+ return result;
+ });
}
@Override
public Read slice(Ranges ranges)
{
- return new ListRead(readKeys.slice(ranges), keys.slice(ranges));
+ return new ListRead(executor, readKeys.slice(ranges), keys.slice(ranges));
}
@Override
public Read merge(Read other)
{
- return new ListRead(((Seekables)readKeys).with(((ListRead)other).readKeys), ((Seekables)keys).with(((ListRead)other).keys));
+ return new ListRead(executor, ((Seekables)readKeys).with(((ListRead)other).readKeys), ((Seekables)keys).with(((ListRead)other).keys));
}
@Override
diff --git a/accord-core/src/test/java/accord/impl/list/ListUpdate.java b/accord-core/src/test/java/accord/impl/list/ListUpdate.java
index 055d1ea..6461e38 100644
--- a/accord-core/src/test/java/accord/impl/list/ListUpdate.java
+++ b/accord-core/src/test/java/accord/impl/list/ListUpdate.java
@@ -21,17 +21,27 @@
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
+import java.util.function.Function;
import java.util.stream.Collectors;
-import accord.api.Key;
import accord.api.Data;
+import accord.api.Key;
import accord.api.Update;
-import accord.primitives.Ranges;
+import accord.local.CommandStore;
import accord.primitives.Keys;
+import accord.primitives.Ranges;
import accord.primitives.Seekables;
+import accord.utils.async.AsyncExecutor;
public class ListUpdate extends TreeMap<Key, Integer> implements Update
{
+ private final Function<? super CommandStore, AsyncExecutor> executor;
+
+ public ListUpdate(Function<? super CommandStore, AsyncExecutor> executor)
+ {
+ this.executor = executor;
+ }
+
@Override
public Seekables<?, ?> keys()
{
@@ -41,7 +51,7 @@
@Override
public ListWrite apply(Data read)
{
- ListWrite write = new ListWrite();
+ ListWrite write = new ListWrite(executor);
Map<Key, int[]> data = (ListData)read;
for (Map.Entry<Key, Integer> e : entrySet())
write.put(e.getKey(), append(data.get(e.getKey()), e.getValue()));
@@ -51,7 +61,7 @@
@Override
public Update slice(Ranges ranges)
{
- ListUpdate result = new ListUpdate();
+ ListUpdate result = new ListUpdate(executor);
for (Map.Entry<Key, Integer> e : entrySet())
{
if (ranges.contains(e.getKey()))
@@ -63,7 +73,7 @@
@Override
public Update merge(Update other)
{
- ListUpdate result = new ListUpdate();
+ ListUpdate result = new ListUpdate(executor);
result.putAll(this);
result.putAll((ListUpdate) other);
return result;
diff --git a/accord-core/src/test/java/accord/impl/list/ListWrite.java b/accord-core/src/test/java/accord/impl/list/ListWrite.java
index 20a2a87..41fefe9 100644
--- a/accord-core/src/test/java/accord/impl/list/ListWrite.java
+++ b/accord-core/src/test/java/accord/impl/list/ListWrite.java
@@ -20,34 +20,47 @@
import java.util.Arrays;
import java.util.TreeMap;
+import java.util.function.Function;
import java.util.stream.Collectors;
-import accord.api.Key;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import accord.api.DataStore;
+import accord.api.Key;
import accord.api.Write;
+import accord.local.CommandStore;
import accord.local.SafeCommandStore;
import accord.primitives.Seekable;
import accord.primitives.Timestamp;
import accord.primitives.Writes;
import accord.utils.Timestamped;
import accord.utils.async.AsyncChain;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import accord.utils.async.AsyncExecutor;
public class ListWrite extends TreeMap<Key, int[]> implements Write
{
private static final Logger logger = LoggerFactory.getLogger(ListWrite.class);
+ private final Function<? super CommandStore, AsyncExecutor> executor;
+
+ public ListWrite(Function<? super CommandStore, AsyncExecutor> executor)
+ {
+ this.executor = executor;
+ }
+
@Override
public AsyncChain<Void> apply(Seekable key, SafeCommandStore commandStore, Timestamp executeAt, DataStore store)
{
ListStore s = (ListStore) store;
if (!containsKey(key))
return Writes.SUCCESS;
- int[] data = get(key);
- s.data.merge((Key)key, new Timestamped<>(executeAt, data), Timestamped::merge);
- logger.trace("WRITE on {} at {} key:{} -> {}", s.node, executeAt, key, data);
- return Writes.SUCCESS;
+ return executor.apply(commandStore.commandStore()).submit(() -> {
+ int[] data = get(key);
+ s.data.merge((Key)key, new Timestamped<>(executeAt, data), Timestamped::merge);
+ logger.trace("WRITE on {} at {} key:{} -> {}", s.node, executeAt, key, data);
+ return null;
+ });
}
@Override
diff --git a/accord-core/src/test/java/accord/local/CheckedCommands.java b/accord-core/src/test/java/accord/local/CheckedCommands.java
new file mode 100644
index 0000000..a8be25e
--- /dev/null
+++ b/accord-core/src/test/java/accord/local/CheckedCommands.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accord.local;
+
+import javax.annotation.Nullable;
+
+import accord.api.Result;
+import accord.api.RoutingKey;
+import accord.primitives.Ballot;
+import accord.primitives.PartialDeps;
+import accord.primitives.PartialRoute;
+import accord.primitives.PartialTxn;
+import accord.primitives.Route;
+import accord.primitives.Seekables;
+import accord.primitives.Timestamp;
+import accord.primitives.TxnId;
+import accord.primitives.Writes;
+
+public class CheckedCommands
+{
+ public static void preaccept(SafeCommandStore safeStore, TxnId txnId, PartialTxn partialTxn, Route<?> route, @Nullable RoutingKey progressKey)
+ {
+ Commands.AcceptOutcome result = Commands.preaccept(safeStore, txnId, partialTxn, route, progressKey);
+ if (result != Commands.AcceptOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result);
+ }
+
+ public static void accept(SafeCommandStore safeStore, TxnId txnId, Ballot ballot, PartialRoute<?> route, Seekables<?, ?> keys, @Nullable RoutingKey progressKey, Timestamp executeAt, PartialDeps partialDeps)
+ {
+ Commands.AcceptOutcome result = Commands.accept(safeStore, txnId, ballot, route, keys, progressKey, executeAt, partialDeps);
+ if (result != Commands.AcceptOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result);
+ }
+
+ public static void commit(SafeCommandStore safeStore, TxnId txnId, Route<?> route, @Nullable RoutingKey progressKey, @Nullable PartialTxn partialTxn, Timestamp executeAt, PartialDeps partialDeps)
+ {
+ Commands.CommitOutcome result = Commands.commit(safeStore, txnId, route, progressKey, partialTxn, executeAt, partialDeps);
+ if (result != Commands.CommitOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result);
+ }
+
+ public static void apply(SafeCommandStore safeStore, TxnId txnId, long untilEpoch, Route<?> route, Timestamp executeAt, @Nullable PartialDeps partialDeps, Writes writes, Result result)
+ {
+ Commands.ApplyOutcome outcome = Commands.apply(safeStore, txnId, untilEpoch, route, executeAt, partialDeps, writes, result);
+ if (outcome != Commands.ApplyOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + outcome);
+ }
+}
diff --git a/accord-core/src/test/java/accord/messages/PreAcceptTest.java b/accord-core/src/test/java/accord/messages/PreAcceptTest.java
index ae339a1..29be803 100644
--- a/accord-core/src/test/java/accord/messages/PreAcceptTest.java
+++ b/accord-core/src/test/java/accord/messages/PreAcceptTest.java
@@ -59,29 +59,10 @@
private static final Id ID3 = id(3);
private static final List<Id> IDS = listOf(ID1, ID2, ID3);
private static final Topology TOPOLOGY = TopologyFactory.toTopology(IDS, 3, IntKey.range(0, 100));
- private static final Ranges RANGE = Ranges.single(IntKey.range(0, 100));
private static final Ranges FULL_RANGE = Ranges.single(IntKey.range(routing(Integer.MIN_VALUE), routing(Integer.MAX_VALUE)));
private static final ReplyContext REPLY_CONTEXT = Network.replyCtxFor(0);
- private static Node createNode(Id nodeId, MessageSink messageSink, Clock clock)
- {
- MockStore store = new MockStore();
- Scheduler scheduler = new ThreadPoolScheduler();
- return new Node(nodeId,
- messageSink,
- new MockConfigurationService(messageSink, EpochFunction.noop(), TOPOLOGY),
- clock,
- () -> store,
- new ShardDistributor.EvenSplit(8, ignore -> new IntKey.Splitter()),
- new TestAgent(),
- new DefaultRandom(),
- scheduler,
- SizeOfIntersectionSorter.SUPPLIER,
- SimpleProgressLog::new,
- InMemoryCommandStores.Synchronized::new);
- }
-
private static PreAccept preAccept(TxnId txnId, Txn txn, RoutingKey homeKey)
{
FullRoute<?> route = txn.keys().toRoute(homeKey);
@@ -98,7 +79,7 @@
{
RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE);
Clock clock = new Clock(100);
- Node node = createNode(ID1, messageSink, clock);
+ Node node = createNode(ID1, TOPOLOGY, messageSink, clock);
messageSink.clearHistory();
try
@@ -137,7 +118,7 @@
{
RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE);
Clock clock = new Clock(100);
- Node node = createNode(ID1, messageSink, clock);
+ Node node = createNode(ID1, TOPOLOGY, messageSink, clock);
try
{
Raw key = IntKey.key(10);
@@ -165,7 +146,7 @@
{
RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE);
Clock clock = new Clock(100);
- Node node = createNode(ID1, messageSink, clock);
+ Node node = createNode(ID1, TOPOLOGY, messageSink, clock);
try
{
Raw key1 = IntKey.key(10);
@@ -201,7 +182,7 @@
{
RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE);
Clock clock = new Clock(100);
- Node node = createNode(ID1, messageSink, clock);
+ Node node = createNode(ID1, TOPOLOGY, messageSink, clock);
messageSink.clearHistory();
Raw key = IntKey.key(10);
try
@@ -228,7 +209,7 @@
{
RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE);
Clock clock = new Clock(100);
- Node node = createNode(ID1, messageSink, clock);
+ Node node = createNode(ID1, TOPOLOGY, messageSink, clock);
try
{
diff --git a/accord-core/src/test/java/accord/messages/ReadDataTest.java b/accord-core/src/test/java/accord/messages/ReadDataTest.java
new file mode 100644
index 0000000..8a6211e
--- /dev/null
+++ b/accord-core/src/test/java/accord/messages/ReadDataTest.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package accord.messages;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import org.junit.jupiter.api.Test;
+
+import accord.Utils;
+import accord.api.Data;
+import accord.api.Key;
+import accord.api.MessageSink;
+import accord.api.Query;
+import accord.api.Read;
+import accord.api.Result;
+import accord.api.RoutingKey;
+import accord.api.Update;
+import accord.api.Write;
+import accord.impl.IntKey;
+import accord.impl.TopologyFactory;
+import accord.impl.mock.MockCluster;
+import accord.local.CheckedCommands;
+import accord.local.Command;
+import accord.local.CommandStore;
+import accord.local.Node;
+import accord.local.PreLoadContext;
+import accord.local.SafeCommand;
+import accord.primitives.Ballot;
+import accord.primitives.FullRoute;
+import accord.primitives.Keys;
+import accord.primitives.PartialDeps;
+import accord.primitives.PartialRoute;
+import accord.primitives.PartialTxn;
+import accord.primitives.Range;
+import accord.primitives.Ranges;
+import accord.primitives.Routable;
+import accord.primitives.Timestamp;
+import accord.primitives.Txn;
+import accord.primitives.TxnId;
+import accord.primitives.Writes;
+import accord.topology.Topologies;
+import accord.topology.Topology;
+import accord.utils.async.AsyncChain;
+import accord.utils.async.AsyncChains;
+import accord.utils.async.AsyncResults;
+import org.assertj.core.api.Assertions;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static accord.Utils.createNode;
+import static accord.Utils.id;
+import static accord.utils.Utils.listOf;
+import static org.mockito.ArgumentMatchers.any;
+
+class ReadDataTest
+{
+ private static final Node.Id ID1 = id(1);
+ private static final Node.Id ID2 = id(2);
+ private static final Node.Id ID3 = id(3);
+ private static final List<Node.Id> IDS = listOf(ID1, ID2, ID3);
+ private static final Range RANGE = IntKey.range(0, 100);
+ private static final Ranges RANGES = Ranges.single(RANGE);
+ private static final Topology TOPOLOGY = TopologyFactory.toTopology(IDS, 3, RANGE);
+ private static final Topologies TOPOLOGIES = Utils.topologies(TOPOLOGY);
+
+ private void test(Consumer<State> fn)
+ {
+ MessageSink sink = Mockito.mock(MessageSink.class);
+ Node node = createNode(ID1, TOPOLOGY, sink, new MockCluster.Clock(100));
+
+ TxnId txnId = node.nextTxnId(Txn.Kind.Write, Routable.Domain.Key);
+ Keys keys = Keys.of(IntKey.key(1), IntKey.key(43));
+
+ AsyncResults.SettableResult<Data> readResult = new AsyncResults.SettableResult<>();
+
+ Read read = Mockito.mock(Read.class);
+ Mockito.when(read.slice(any())).thenReturn(read);
+ Mockito.when(read.merge(any())).thenReturn(read);
+ Mockito.when(read.read(any(), any(), any(), any(), any())).thenAnswer(new Answer<AsyncChain<Data>>()
+ {
+ private boolean called = false;
+ @Override
+ public AsyncChain<Data> answer(InvocationOnMock ignore) throws Throwable
+ {
+ if (called) throw new IllegalStateException("Multiple calls");
+ return readResult;
+ }
+ });
+ Query query = Mockito.mock(Query.class);
+ Update update = Mockito.mock(Update.class);
+ Mockito.when(update.slice(any())).thenReturn(update);
+
+ Txn txn = new Txn.InMemory(keys, read, query, update);
+ PartialTxn partialTxn = txn.slice(RANGES, true);
+
+ fn.accept(new State(node, sink, txnId, partialTxn, readResult));
+ }
+
+ @Test
+ public void readyToExecuteObsoleteFromTracker()
+ {
+ // status=ReadyToExecute, so read will happen right away; obsolete marked by ObsoleteTracker
+ test(state -> {
+ state.readyToExecute();
+
+ ReplyContext replyContext = state.process();
+ Mockito.verifyNoInteractions(state.sink);
+
+ state.apply();
+ state.readResult.setSuccess(Mockito.mock(Data.class));
+ Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant));
+ });
+ }
+
+ @Test
+ public void commitObsoleteFromTracker()
+ {
+ // status=Commit, will listen waiting for ReadyToExecute; obsolete marked by status listener
+ test(state -> {
+ state.forEach(store -> check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> {
+ CheckedCommands.preaccept(safe, state.txnId, state.partialTxn, state.route, state.progressKey);
+ CheckedCommands.accept(safe, state.txnId, Ballot.ZERO, state.partialRoute, state.partialTxn.keys(), state.progressKey, state.executeAt, state.deps);
+
+ SafeCommand safeCommand = safe.command(state.txnId);
+ safeCommand.commit(safeCommand.current(), state.executeAt, Command.WaitingOn.EMPTY);
+ })));
+
+ ReplyContext replyContext = state.process();
+
+ Mockito.verifyNoInteractions(state.sink);
+
+ state.apply();
+ state.readResult.setSuccess(Mockito.mock(Data.class));
+
+ Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant));
+ });
+ }
+
+ @Test
+ public void mapReduceMarksObsolete()
+ {
+ // status=Commit, will listen waiting for ReadyToExecute; obsolete marked by status listener
+ test(state -> {
+ List<CommandStore> stores = stores(state);
+ // this test is a bit implementation specific... so if implementations change this may need an update
+ // since mapReduceConsume walks the store in id order, by making sure the stores involved in this test
+ // are in the "right" order, can make sure to hit a very specific edge case
+ Collections.sort(stores, Comparator.comparingInt(CommandStore::id));
+ CommandStore store = stores.get(0);
+
+ // ack doesn't get called due to waitingOnCount not being -1, can only happen once
+ // the process command completes
+ state.readResult.setSuccess(Mockito.mock(Data.class));
+ state.readyToExecute(store);
+
+ store = stores.get(1);
+ check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> {
+ SafeCommand command = safe.command(state.txnId);
+ command.commitInvalidated(command.current(), state.executeAt);
+ }));
+
+ ReplyContext replyContext = state.process();
+
+ Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant));
+ });
+ }
+
+ @Test
+ public void mapReduceAllStageMarksObsolete()
+ {
+ test(state -> {
+ List<CommandStore> stores = stores(state);
+ stores.forEach(store -> check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> {
+ SafeCommand command = safe.command(state.txnId);
+ command.commitInvalidated(command.current(), state.executeAt);
+ })));
+ ReplyContext replyContext = state.process();
+
+ Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant));
+ });
+ }
+
+ private static List<CommandStore> stores(State state)
+ {
+ List<CommandStore> stores = new ArrayList<>(2);
+ state.forEach(stores::add);
+ Assertions.assertThat(stores).hasSize(2);
+ // block duplicate stores
+ Map<Integer, Long> counts = stores.stream().map(CommandStore::id).collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
+ for (Map.Entry<Integer, Long> e : counts.entrySet())
+ {
+ if (e.getValue() == 1) continue;
+ throw new AssertionError("Duplicate command store detected with id: " + e.getKey());
+ }
+ return stores;
+ }
+
+ private static void check(AsyncChain<Void> execute)
+ {
+ try
+ {
+ AsyncChains.getUninterruptibly(execute);
+ }
+ catch (ExecutionException e)
+ {
+ throw new AssertionError(e.getCause());
+ }
+ }
+
+ private static class State
+ {
+ private final Node node;
+ private final MessageSink sink;
+ private final TxnId txnId;
+ private final PartialTxn partialTxn;
+ private final Keys keys;
+ private final Key key;
+ private final FullRoute<?> route;
+ private final PartialRoute<?> partialRoute;
+ private final RoutingKey progressKey;
+ private final Timestamp executeAt;
+ private final PartialDeps deps;
+ private final AsyncResults.SettableResult<Data> readResult;
+
+ State(Node node, MessageSink sink, TxnId txnId, PartialTxn partialTxn, AsyncResults.SettableResult<Data> readResult)
+ {
+ this.node = node;
+ this.sink = sink;
+ this.txnId = txnId;
+ this.partialTxn = partialTxn;
+ this.keys = (Keys) partialTxn.keys();
+ this.key = keys.get(0);
+ this.route = keys.toRoute(key.toUnseekable());
+ this.partialRoute = route.slice(RANGES);
+ this.progressKey = key.toUnseekable();
+ this.executeAt = txnId;
+ this.deps = PartialDeps.builder(RANGES).build();
+ this.readResult = readResult;
+ }
+
+ void readyToExecute(CommandStore store)
+ {
+ check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> {
+ CheckedCommands.preaccept(safe, txnId, partialTxn, route, progressKey);
+ CheckedCommands.accept(safe, txnId, Ballot.ZERO, partialRoute, partialTxn.keys(), progressKey, executeAt, deps);
+ CheckedCommands.commit(safe, txnId, route, progressKey, partialTxn, executeAt, deps);
+ }));
+ }
+
+ void readyToExecute()
+ {
+ forEach(this::readyToExecute);
+ }
+
+ private void forEach(Consumer<CommandStore> fn)
+ {
+ keys.stream().map(node.commandStores()::unsafeForKey).distinct().forEach(fn);
+ }
+
+ AsyncResults.SettableResult<Void> apply()
+ {
+ AsyncResults.SettableResult<Void> writeResult = new AsyncResults.SettableResult<>();
+ Write write = Mockito.mock(Write.class);
+ Mockito.when(write.apply(any(), any(), any(), any())).thenReturn(writeResult);
+ Writes writes = new Writes(executeAt, keys, write);
+
+ forEach(store -> check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> {
+ CheckedCommands.apply(safe, txnId, safe.latestEpoch(), route, executeAt, deps, writes, Mockito.mock(Result.class));
+ })));
+ return writeResult;
+ }
+
+ ReplyContext process()
+ {
+ ReplyContext replyContext = Mockito.mock(ReplyContext.class);
+ ReadData readData = new ReadData(node.id(), TOPOLOGIES, txnId, keys, txnId);
+ readData.process(node, node.id(), replyContext);
+ return replyContext;
+ }
+ }
+}
\ No newline at end of file
diff --git a/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java b/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java
index 98e6217..8ae2b89 100644
--- a/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java
+++ b/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java
@@ -610,6 +610,8 @@
{
if (maybeWrite >= 0)
{
+ if (IntStream.of(sequence).anyMatch(i -> i == maybeWrite))
+ throw new HistoryViolation(key, "Attempted to write " + maybeWrite + " which is already found in the seq; seq=" + Arrays.toString(sequence));
sequence = Arrays.copyOf(sequence, sequence.length + 1);
sequence[sequence.length - 1] = maybeWrite;
}
diff --git a/buildSrc/src/main/groovy/accord.java-conventions.gradle b/buildSrc/src/main/groovy/accord.java-conventions.gradle
index 9e663ee..5817cc9 100644
--- a/buildSrc/src/main/groovy/accord.java-conventions.gradle
+++ b/buildSrc/src/main/groovy/accord.java-conventions.gradle
@@ -47,6 +47,7 @@
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0'
testImplementation group: 'org.assertj', name: 'assertj-core', version: '3.24.2'
+ testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.7.0'
}
task copyMainDependencies(type: Copy) {