RATIS-1913. Assert that the primary peers in DataStreamClient and RoutingTable are equal (#945)
diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
index 7a25a8e..353f532 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
@@ -44,6 +44,7 @@
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
+import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.SlidingWindow;
import java.io.IOException;
@@ -230,6 +231,12 @@
@Override
public DataStreamOutputRpc stream(ByteBuffer headerMessage, RoutingTable routingTable) {
+ if (routingTable != null) {
+ // Validate that the primary peer is equal to the primary peer passed by the RoutingTable
+ Preconditions.assertTrue(dataStreamServer.getId().equals(routingTable.getPrimary()),
+ () -> "Primary peer mismatched: the routing table has " + routingTable.getPrimary()
+ + " but the client has " + dataStreamServer.getId());
+ }
final Message message =
Optional.ofNullable(headerMessage).map(ByteString::copyFrom).map(Message::valueOf).orElse(null);
RaftClientRequest request = RaftClientRequest.newBuilder()
diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
index 0157fe4..56181c4 100644
--- a/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
@@ -43,6 +43,9 @@
/** @return the successor peers of the given peer. */
Set<RaftPeerId> getSuccessors(RaftPeerId peerId);
+ /** @return the primary peer. */
+ RaftPeerId getPrimary();
+
/** @return the proto of this {@link RoutingTable}. */
RoutingTableProto toProto();
@@ -78,15 +81,15 @@
}
public RoutingTable build() {
- return Optional.ofNullable(ref.getAndSet(null))
- .map(RoutingTable::newRoutingTable)
- .orElseThrow(() -> new IllegalStateException("RoutingTable Already built"));
+ final Map<RaftPeerId, Set<RaftPeerId>> map = ref.getAndSet(null);
+ if (map == null) {
+ throw new IllegalStateException("RoutingTable is already built.");
+ }
+ return RoutingTable.newRoutingTable(map);
}
- static void validate(Map<RaftPeerId, Set<RaftPeerId>> map) {
- if (map != null && !map.isEmpty()) {
- new Builder.Validation(map).run();
- }
+ static RaftPeerId validate(Map<RaftPeerId, Set<RaftPeerId>> map) {
+ return new Builder.Validation(map).run();
}
/** Validate if a map represents a valid routing table. */
@@ -131,10 +134,11 @@
this.unreachablePeers = allPeers;
}
- private void run() {
+ private RaftPeerId run() {
depthFirstSearch(primary);
Preconditions.assertTrue(unreachablePeers.isEmpty() ,
() -> "Invalid routing table: peer(s) " + unreachablePeers + " are unreachable, " + this);
+ return primary;
}
private void depthFirstSearch(RaftPeerId current) {
@@ -159,7 +163,10 @@
/** @return a new {@link RoutingTable} represented by the given map. */
static RoutingTable newRoutingTable(Map<RaftPeerId, Set<RaftPeerId>> map){
- Builder.validate(map);
+ if (map == null || map.isEmpty()) {
+ return null;
+ }
+ final RaftPeerId primary = Builder.validate(map);
final Supplier<RoutingTableProto> proto = JavaUtils.memoize(
() -> RoutingTableProto.newBuilder().addAllRoutes(ProtoUtils.toRouteProtos(map)).build());
@@ -170,6 +177,11 @@
}
@Override
+ public RaftPeerId getPrimary() {
+ return primary;
+ }
+
+ @Override
public RoutingTableProto toProto() {
return proto.get();
}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index fd47045..dcb54be 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -62,6 +62,11 @@
runWithNewCluster(NUM_SERVERS, this::testStreamWrites);
}
+ @Test
+ public void testStreamWithInvalidRoutingTable() throws Exception {
+ runWithNewCluster(NUM_SERVERS, this::runTestInvalidPrimaryInRoutingTable);
+ }
+
void testStreamWrites(CLUSTER cluster) throws Exception {
waitForLeader(cluster);
runTestDataStreamOutput(cluster);
@@ -96,6 +101,31 @@
assertLogEntry(cluster, request);
}
+ void runTestInvalidPrimaryInRoutingTable(CLUSTER cluster) throws Exception {
+ final RaftPeer primaryServer = CollectionUtils.random(cluster.getGroup().getPeers());
+
+ RaftPeer notPrimary = null;
+ for (RaftPeer peer: cluster.getGroup().getPeers()) {
+ if (!peer.equals(primaryServer)) {
+ notPrimary = peer;
+ break;
+ }
+ }
+
+ Assert.assertNotNull(
+ "Cannot find peer other than the primary", notPrimary);
+ Assert.assertNotEquals(primaryServer, notPrimary);
+
+ try (RaftClient client = cluster.createClient(primaryServer)) {
+ RoutingTable routingTableWithWrongPrimary =
+ getRoutingTable(cluster.getGroup().getPeers(), notPrimary);
+ testFailureCase("",
+ () -> client.getDataStreamApi().stream(null,
+ routingTableWithWrongPrimary),
+ IllegalStateException.class);
+ }
+ }
+
void runTestWriteFile(CLUSTER cluster, int i,
CheckedConsumer<DataStreamOutputImpl, Exception> testCase) throws Exception {
final RaftClientRequest request;
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
index 06702f9..8e423ab 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
@@ -62,6 +62,11 @@
@Ignore
@Override
+ public void testStreamWithInvalidRoutingTable() {
+ }
+
+ @Ignore
+ @Override
public void testMultipleStreamsMultipleServers() {
}