// blob: 8a6211efd0e37cf70edda1f6cb9a170b9a5d8fd3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package accord.messages;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import accord.Utils;
import accord.api.Data;
import accord.api.Key;
import accord.api.MessageSink;
import accord.api.Query;
import accord.api.Read;
import accord.api.Result;
import accord.api.RoutingKey;
import accord.api.Update;
import accord.api.Write;
import accord.impl.IntKey;
import accord.impl.TopologyFactory;
import accord.impl.mock.MockCluster;
import accord.local.CheckedCommands;
import accord.local.Command;
import accord.local.CommandStore;
import accord.local.Node;
import accord.local.PreLoadContext;
import accord.local.SafeCommand;
import accord.primitives.Ballot;
import accord.primitives.FullRoute;
import accord.primitives.Keys;
import accord.primitives.PartialDeps;
import accord.primitives.PartialRoute;
import accord.primitives.PartialTxn;
import accord.primitives.Range;
import accord.primitives.Ranges;
import accord.primitives.Routable;
import accord.primitives.Timestamp;
import accord.primitives.Txn;
import accord.primitives.TxnId;
import accord.primitives.Writes;
import accord.topology.Topologies;
import accord.topology.Topology;
import accord.utils.async.AsyncChain;
import accord.utils.async.AsyncChains;
import accord.utils.async.AsyncResults;
import org.assertj.core.api.Assertions;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static accord.Utils.createNode;
import static accord.Utils.id;
import static accord.utils.Utils.listOf;
import static org.mockito.ArgumentMatchers.any;
class ReadDataTest
{
// Ids of the three nodes used to build the test topology.
private static final Node.Id ID1 = id(1);
private static final Node.Id ID2 = id(2);
private static final Node.Id ID3 = id(3);
// All node ids, in the order handed to the topology factory.
private static final List<Node.Id> IDS = listOf(ID1, ID2, ID3);
// The single IntKey range (built from 0 and 100) that every test key falls into.
private static final Range RANGE = IntKey.range(0, 100);
// RANGE wrapped as a Ranges; used to slice the txn, the route and the deps.
private static final Ranges RANGES = Ranges.single(RANGE);
// Topology over IDS and RANGE; presumably the 3 is the replication factor -- confirm against TopologyFactory.
private static final Topology TOPOLOGY = TopologyFactory.toTopology(IDS, 3, RANGE);
// TOPOLOGY wrapped as Topologies, as passed to the ReadData constructor.
private static final Topologies TOPOLOGIES = Utils.topologies(TOPOLOGY);
/**
 * Builds the fixture shared by every test case and passes it to {@code fn}.
 *
 * <p>The fixture consists of a node created for {@code ID1} over {@code TOPOLOGY} with a mocked
 * {@link MessageSink}, a fresh write {@link TxnId}, and a two-key transaction (keys 1 and 43)
 * whose {@link Read}, {@link Query} and {@link Update} are Mockito mocks. Every call to
 * {@code read.read(...)} returns the same settable {@code readResult}, so the test body decides
 * when (and whether) the read completes.
 *
 * @param fn the test body; receives a {@link State} wrapping the fixture
 */
private void test(Consumer<State> fn)
{
    MessageSink sink = Mockito.mock(MessageSink.class);
    Node node = createNode(ID1, TOPOLOGY, sink, new MockCluster.Clock(100));
    TxnId txnId = node.nextTxnId(Txn.Kind.Write, Routable.Domain.Key);
    Keys keys = Keys.of(IntKey.key(1), IntKey.key(43));

    AsyncResults.SettableResult<Data> readResult = new AsyncResults.SettableResult<>();
    Read read = Mockito.mock(Read.class);
    // slicing or merging the mocked read hands back the same mock
    Mockito.when(read.slice(any())).thenReturn(read);
    Mockito.when(read.merge(any())).thenReturn(read);
    // Every read returns the shared result. An earlier "Multiple calls" guard keyed off a flag
    // that was never set, so it could never fire; it is dropped here instead of being armed.
    // NOTE(review): arming it would reject a second read -- confirm how many reads a two-key
    // txn is expected to issue before reintroducing such a check.
    Mockito.when(read.read(any(), any(), any(), any(), any())).thenAnswer(ignore -> readResult);

    Query query = Mockito.mock(Query.class);
    Update update = Mockito.mock(Update.class);
    Mockito.when(update.slice(any())).thenReturn(update);

    Txn txn = new Txn.InMemory(keys, read, query, update);
    PartialTxn partialTxn = txn.slice(RANGES, true);
    fn.accept(new State(node, sink, txnId, partialTxn, readResult));
}
/**
 * Scenario: the command is preaccepted, accepted and committed on every store before the
 * ReadData is processed (status=ReadyToExecute, so the read starts right away). The txn is
 * then applied before the read completes; once the read result is delivered the sink must
 * receive a {@code Redundant} nack (obsolete marked by ObsoleteTracker).
 */
@Test
public void readyToExecuteObsoleteFromTracker()
{
    test(fixture -> {
        fixture.readyToExecute();

        ReplyContext context = fixture.process();
        // the read is still pending, so nothing may have been sent yet
        Mockito.verifyNoInteractions(fixture.sink);

        fixture.apply();
        Data data = Mockito.mock(Data.class);
        fixture.readResult.setSuccess(data);

        Mockito.verify(fixture.sink)
               .reply(Mockito.eq(fixture.node.id()),
                      Mockito.eq(context),
                      Mockito.eq(ReadData.ReadNack.Redundant));
    });
}
/**
 * Scenario: on every store the command is preaccepted, accepted and then committed directly
 * through {@link SafeCommand#commit} with an empty {@code WaitingOn} (status=Commit; the read
 * waits for ReadyToExecute and obsolescence is marked by the status listener). The txn is
 * applied and the read result delivered afterwards; the sink must then receive a
 * {@code Redundant} nack.
 */
@Test
public void commitObsoleteFromTracker()
{
    test(fixture -> {
        fixture.forEach(commandStore -> {
            check(commandStore.execute(PreLoadContext.contextFor(fixture.txnId, fixture.keys), safeStore -> {
                CheckedCommands.preaccept(safeStore, fixture.txnId, fixture.partialTxn, fixture.route, fixture.progressKey);
                CheckedCommands.accept(safeStore, fixture.txnId, Ballot.ZERO, fixture.partialRoute,
                                       fixture.partialTxn.keys(), fixture.progressKey, fixture.executeAt, fixture.deps);
                SafeCommand committing = safeStore.command(fixture.txnId);
                committing.commit(committing.current(), fixture.executeAt, Command.WaitingOn.EMPTY);
            }));
        });

        ReplyContext context = fixture.process();
        // no reply is expected before the command is applied
        Mockito.verifyNoInteractions(fixture.sink);

        fixture.apply();
        Data data = Mockito.mock(Data.class);
        fixture.readResult.setSuccess(data);

        Mockito.verify(fixture.sink)
               .reply(Mockito.eq(fixture.node.id()),
                      Mockito.eq(context),
                      Mockito.eq(ReadData.ReadNack.Redundant));
    });
}
/**
 * Scenario: of the two command stores, the one with the lower id has the command preaccepted,
 * accepted and committed (with the read result already completed), while the other has the
 * command commit-invalidated. Processing the ReadData must then produce a {@code Redundant} nack.
 */
@Test
public void mapReduceMarksObsolete()
{
    test(fixture -> {
        List<CommandStore> ordered = stores(fixture);
        // this test is a bit implementation specific... so if implementations change this may need an update
        // since mapReduceConsume walks the store in id order, by making sure the stores involved in this test
        // are in the "right" order, can make sure to hit a very specific edge case
        ordered.sort(Comparator.comparingInt(CommandStore::id));
        CommandStore first = ordered.get(0);
        CommandStore second = ordered.get(1);

        // ack doesn't get called due to waitingOnCount not being -1, can only happen once
        // the process command completes
        fixture.readResult.setSuccess(Mockito.mock(Data.class));
        fixture.readyToExecute(first);

        check(second.execute(PreLoadContext.contextFor(fixture.txnId, fixture.keys), safeStore -> {
            SafeCommand invalidated = safeStore.command(fixture.txnId);
            invalidated.commitInvalidated(invalidated.current(), fixture.executeAt);
        }));

        ReplyContext context = fixture.process();
        Mockito.verify(fixture.sink)
               .reply(Mockito.eq(fixture.node.id()),
                      Mockito.eq(context),
                      Mockito.eq(ReadData.ReadNack.Redundant));
    });
}
/**
 * Scenario: the command is commit-invalidated on both command stores before the ReadData is
 * processed; processing must answer with a {@code Redundant} nack.
 */
@Test
public void mapReduceAllStageMarksObsolete()
{
    test(fixture -> {
        for (CommandStore commandStore : stores(fixture))
        {
            check(commandStore.execute(PreLoadContext.contextFor(fixture.txnId, fixture.keys), safeStore -> {
                SafeCommand invalidated = safeStore.command(fixture.txnId);
                invalidated.commitInvalidated(invalidated.current(), fixture.executeAt);
            }));
        }

        ReplyContext context = fixture.process();
        Mockito.verify(fixture.sink)
               .reply(Mockito.eq(fixture.node.id()),
                      Mockito.eq(context),
                      Mockito.eq(ReadData.ReadNack.Redundant));
    });
}
/**
 * Collects the command stores visited by {@code state.forEach} into a mutable list.
 *
 * <p>Asserts that exactly two stores were collected and that no store id occurs more than once.
 *
 * @param state fixture whose command stores are collected
 * @return the collected stores, in the order {@code state.forEach} produced them
 * @throws AssertionError if the number of stores is not 2, or a store id is duplicated
 */
private static List<CommandStore> stores(State state)
{
    List<CommandStore> found = new ArrayList<>(2);
    state.forEach(found::add);
    Assertions.assertThat(found).hasSize(2);

    // block duplicate stores: count how often each store id shows up
    Map<Integer, Long> occurrences = found.stream()
                                          .map(CommandStore::id)
                                          .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    occurrences.forEach((id, count) -> {
        if (count != 1)
            throw new AssertionError("Duplicate command store detected with id: " + id);
    });
    return found;
}
/**
 * Resolves {@code execute} through {@code AsyncChains.getUninterruptibly} and converts a failure
 * into a test failure.
 *
 * @param execute the chain to resolve
 * @throws AssertionError carrying the cause of the {@link ExecutionException} if resolution fails
 */
private static void check(AsyncChain<Void> execute)
{
    try
    {
        AsyncChains.getUninterruptibly(execute);
    }
    catch (ExecutionException failure)
    {
        // surface the underlying failure rather than the ExecutionException wrapper
        Throwable underlying = failure.getCause();
        throw new AssertionError(underlying);
    }
}
private static class State
{
private final Node node;
private final MessageSink sink;
private final TxnId txnId;
private final PartialTxn partialTxn;
private final Keys keys;
private final Key key;
private final FullRoute<?> route;
private final PartialRoute<?> partialRoute;
private final RoutingKey progressKey;
private final Timestamp executeAt;
private final PartialDeps deps;
private final AsyncResults.SettableResult<Data> readResult;
State(Node node, MessageSink sink, TxnId txnId, PartialTxn partialTxn, AsyncResults.SettableResult<Data> readResult)
{
this.node = node;
this.sink = sink;
this.txnId = txnId;
this.partialTxn = partialTxn;
this.keys = (Keys) partialTxn.keys();
this.key = keys.get(0);
this.route = keys.toRoute(key.toUnseekable());
this.partialRoute = route.slice(RANGES);
this.progressKey = key.toUnseekable();
this.executeAt = txnId;
this.deps = PartialDeps.builder(RANGES).build();
this.readResult = readResult;
}
void readyToExecute(CommandStore store)
{
check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> {
CheckedCommands.preaccept(safe, txnId, partialTxn, route, progressKey);
CheckedCommands.accept(safe, txnId, Ballot.ZERO, partialRoute, partialTxn.keys(), progressKey, executeAt, deps);
CheckedCommands.commit(safe, txnId, route, progressKey, partialTxn, executeAt, deps);
}));
}
void readyToExecute()
{
forEach(this::readyToExecute);
}
private void forEach(Consumer<CommandStore> fn)
{
keys.stream().map(node.commandStores()::unsafeForKey).distinct().forEach(fn);
}
AsyncResults.SettableResult<Void> apply()
{
AsyncResults.SettableResult<Void> writeResult = new AsyncResults.SettableResult<>();
Write write = Mockito.mock(Write.class);
Mockito.when(write.apply(any(), any(), any(), any())).thenReturn(writeResult);
Writes writes = new Writes(executeAt, keys, write);
forEach(store -> check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> {
CheckedCommands.apply(safe, txnId, safe.latestEpoch(), route, executeAt, deps, writes, Mockito.mock(Result.class));
})));
return writeResult;
}
ReplyContext process()
{
ReplyContext replyContext = Mockito.mock(ReplyContext.class);
ReadData readData = new ReadData(node.id(), TOPOLOGIES, txnId, keys, txnId);
readData.process(node, node.id(), replyContext);
return replyContext;
}
}
}